Uddeshsingh/q4k fused kernels#20231
Replace the export-time GGUF-to-MLX qparam repack path with fused Metal kernels.
Keep the legacy MLX-native repack path available when ET_MLX_EMIT_DIRECT_GGUF is set to 0, per maintainer request on pytorch#20172.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds fused Q4_K Metal kernel support (linear + embedding) to the MLX GGUF lowering path, with an environment-variable switch to fall back to the legacy MLX-native repack implementation.
Changes:
- Extend GGUF linear/embedding tests to cover both Q6_K and Q4_K.
- Implement fused Q4_K Metal kernels for linear (mat-vec + mat-mat + dynamic IfNode) and embedding gather.
- Add ET_MLX_EMIT_DIRECT_GGUF-controlled dispatch between fused-kernel and legacy repack paths.
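
The review describes an environment-variable gate but does not show its body. The following is a minimal sketch of how such a gate could look, not the PR's implementation: the function name `emit_direct_gguf_sketch`, the default of "on" when the variable is unset, and the rule that only the literal value `0` disables the fused path are all assumptions drawn from the PR description.

```python
import os
from typing import Mapping, Optional

ENV_VAR = "ET_MLX_EMIT_DIRECT_GGUF"  # name taken from the PR text


def emit_direct_gguf_sketch(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when the fused-kernel path should be used.

    Hypothetical stand-in for the PR's emit_direct_gguf(); the default and
    the parsing rule are assumptions, not confirmed by the source.
    """
    env = os.environ if environ is None else environ
    # Assumed rule: only an explicit "0" selects the legacy repack path.
    return env.get(ENV_VAR, "1").strip() != "0"


def pick_q4k_linear_path(environ: Optional[Mapping[str, str]] = None) -> str:
    """Name the lowering that would be chosen for a Q4_K linear."""
    return "fused" if emit_direct_gguf_sketch(environ) else "legacy_repack"
```

Passing the environment as an argument keeps the gate testable without mutating `os.environ`.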
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| backends/mlx/custom_kernel_ops/gguf/test/test_linear.py | Updates linear tests to include Q4_K configs and adjusts reference path assumptions. |
| backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py | Updates embedding tests to generate Q4_K blobs and run additional Q4_K cases. |
| backends/mlx/custom_kernel_ops/gguf/q4k/repack_mlx.py | New helper to repack raw Q4_K GGUF blobs into MLX qparams for the legacy path. |
| backends/mlx/custom_kernel_ops/gguf/q4k/linear_mlx_native.py | New legacy Q4_K lowering via MLX native quantized matmul using repacked qparams. |
| backends/mlx/custom_kernel_ops/gguf/q4k/linear.py | Replaces prior approach with fused Metal mat-vec/mat-mat kernels + dynamic dispatch. |
| backends/mlx/custom_kernel_ops/gguf/q4k/embedding_mlx_native.py | New legacy Q4_K lowering via MLX native quantized gather using repacked qparams. |
| backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py | New fused Metal gather kernel reading raw Q4_K GGUF bytes directly. |
| backends/mlx/custom_kernel_ops/gguf/q4k/common.py | Adds shared Q4_K constants + shared Metal header (block layout + dequant helpers). |
| backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py | Adds emit_direct_gguf() env-var gate and updated package documentation. |
| backends/mlx/custom_kernel_ops/gguf/patterns.py | Dispatches Q4_K lowering between fused kernels and legacy repack path based on env var. |
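
The block layout and dequant helpers that the table attributes to common.py are not shown in this review. For orientation, here is a pure-Python sketch of dequantizing one Q4_K super-block following the layout used by ggml's reference code (fp16 `d`, fp16 `dmin`, 12 bytes of packed 6-bit scales/mins, 128 bytes of 4-bit quants). Whether the PR's Metal header matches this byte for byte is an assumption; the function names are illustrative only.

```python
import struct
from typing import List, Tuple

QK_K = 256                             # weights per Q4_K super-block
BLOCK_BYTES = 2 + 2 + 12 + QK_K // 2   # d, dmin, scales, qs = 144 bytes


def get_scale_min_k4(j: int, scales: bytes) -> Tuple[int, int]:
    """Unpack the 6-bit (scale, min) pair for sub-block j (0..7)."""
    if j < 4:
        return scales[j] & 63, scales[j + 4] & 63
    sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4)
    mn = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return sc, mn


def dequantize_q4k_block(block: bytes) -> List[float]:
    """Dequantize one 144-byte Q4_K block into 256 floats."""
    assert len(block) == BLOCK_BYTES
    d, dmin = struct.unpack_from("<ee", block, 0)  # two little-endian fp16
    scales = block[4:16]
    qs = block[16:]
    out: List[float] = []
    for pair in range(4):  # each 32-byte chunk feeds two 32-weight sub-blocks
        q = qs[32 * pair : 32 * pair + 32]
        sc1, m1 = get_scale_min_k4(2 * pair, scales)
        sc2, m2 = get_scale_min_k4(2 * pair + 1, scales)
        out.extend(d * sc1 * (b & 0x0F) - dmin * m1 for b in q)  # low nibbles
        out.extend(d * sc2 * (b >> 4) - dmin * m2 for b in q)    # high nibbles
    return out
```

A scalar reference like this is mainly useful as a test oracle; a fused kernel would perform the same arithmetic per thread without materializing the full row.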
    constexpr short NL = 16; // Q4_K: QK_K / 16
    constexpr short NL0 = NK / 16; // = 2: dequant iterations per thread for weight
    constexpr short NL1 = NK / 8; // = 4: load iterations per thread for activation

    Both Q6_K and Q4_K kernels dequantize the raw GGUF blob in-kernel; use the
    gguf-exact dequant as the reference oracle.
    """
    lin = model.linear
    weight = lin.weight
    if getattr(weight, "ggml_type", None) == "q4_k":
        # Q4_K is repacked into bf16 MLX affine qparams (S, Q, B); reconstruct
        # exactly what the kernel dequantizes so the oracle isolates kernel
        # accumulation (repack precision vs gguf is covered by test_gguf.py).
        from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams

        intx = weight.to_intx_unpacked_to_int8_tensor()
        gs = int(intx.block_size[-1])
        Q, B = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4)
        qb = Q.view(torch.uint8)
        nibbles = torch.stack([(qb & 0xF).float(), ((qb >> 4) & 0xF).float()], dim=-1)
        q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1)
        scale = intx.scale.float().repeat_interleave(gs, dim=1)
        bias_b = B.float().repeat_interleave(gs, dim=1)
        w = scale * q_unsigned + bias_b
    else:
        w = weight.dequantize(torch.float32)
    w = weight.dequantize(torch.float32)

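
The repack-based oracle in the snippet above boils down to two steps: split each packed byte into a low and a high 4-bit value, then apply a per-group affine map `w = scale * q + bias`. The sketch below restates that arithmetic in plain Python without torch; the helper names are hypothetical and the low-nibble-first order is taken from the `torch.stack([... & 0xF, ... >> 4])` line in the snippet.

```python
from typing import List, Sequence


def unpack_nibbles(packed: bytes) -> List[int]:
    """Split each byte into (low nibble, high nibble), low first."""
    out: List[int] = []
    for b in packed:
        out.append(b & 0x0F)
        out.append((b >> 4) & 0x0F)
    return out


def affine_dequant(
    q: Sequence[int],
    scales: Sequence[float],
    biases: Sequence[float],
    group_size: int,
) -> List[float]:
    """Apply w = scale * q + bias, one (scale, bias) per group of weights."""
    assert len(q) == len(scales) * group_size == len(biases) * group_size
    return [
        scales[i // group_size] * v + biases[i // group_size]
        for i, v in enumerate(q)
    ]


# Example: two bytes hold four 4-bit values, dequantized in groups of two.
row = affine_dequant(unpack_nibbles(bytes([0x21, 0x43])), [0.5, 2.0], [-1.0, 0.0], 2)
```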
        K: int,
        out: Slot,
    ) -> None:
        in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype)

        output_names=["out"],
        output_shapes_flat=out_shape_flat,
        output_shape_lengths=[len(out_shape_flat)],
        output_dtypes=[in_dtype_int],

    out_shape_flat = leading + [IntOrVid.from_literal(K)]

    # threadgroup.x must divide grid.x (= K, a multiple of 256).
    tg_x = 256 if K % 256 == 0 else K

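
To make the divisibility rule in the comment concrete, the sketch below mirrors the snippet's selection and, purely as an illustration, shows a hypothetical alternative that picks the largest divisor of K not exceeding a cap. The second helper is not part of the PR; it only demonstrates one way to keep threadgroup.x bounded if K were ever not a multiple of 256.

```python
def threadgroup_x(K: int, preferred: int = 256) -> int:
    """Same rule as the snippet: use `preferred` when it divides K, else K."""
    return preferred if K % preferred == 0 else K


def largest_divisor_at_most(K: int, limit: int = 256) -> int:
    """Hypothetical alternative: largest divisor of K that is <= limit."""
    for candidate in range(min(K, limit), 0, -1):
        if K % candidate == 0:
            return candidate
    raise ValueError("K must be a positive integer")


# For Q4_K weights K is a multiple of 256, so both helpers agree.
assert threadgroup_x(4096) == largest_divisor_at_most(4096) == 256
```

For any valid Q4_K shape the fallback branch is never taken, since K is a whole number of 256-weight blocks.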
    const uint j = thread_position_in_grid.x; // 0..K-1
    const uint r = thread_position_in_grid.y; // gathered row
    const int row = (int) indices[r];
    const int nb = K / QK_K;
    device const block_q4_K * blk =
        ((device const block_q4_K *) weight) + (uint)row * nb + (j / QK_K);

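
The pointer arithmetic in the gather kernel maps a (row, column) pair to a block index of `row * nb + j / QK_K`. The sketch below reproduces that index math on the host side. The 144-byte block size comes from the ggml Q4_K layout and is an assumption about this PR's `block_q4_K`; the function name is hypothetical.

```python
from typing import Tuple

QK_K = 256          # weights per Q4_K super-block
BLOCK_BYTES = 144   # assumed sizeof(block_q4_K), per the ggml layout


def locate_q4k_element(row: int, j: int, K: int) -> Tuple[int, int, int]:
    """Return (block_index, byte_offset_of_block, index_within_block).

    Mirrors `weight + row * nb + (j / QK_K)` from the kernel snippet.
    """
    if K % QK_K != 0:
        raise ValueError("K must be a multiple of QK_K")
    if not 0 <= j < K:
        raise ValueError("column index out of range")
    nb = K // QK_K                       # blocks per weight row
    block_index = row * nb + j // QK_K
    return block_index, block_index * BLOCK_BYTES, j % QK_K


# Example: row 3, column 300 of a K=512 table lands in block 7.
example = locate_q4k_element(3, 300, 512)
```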
    if emit_direct_gguf():
        from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
            emit_linear,
        )
    else:
        from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import (
            emit_linear,
        )

Fixes #20172
Summary
Setting ET_MLX_EMIT_DIRECT_GGUF=0 selects the legacy MLX-native repack path.
Test plan
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run