From 805a09d527cae3858174d7b71166a8da72f3ed2d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Apr 2026 12:25:30 -0400 Subject: [PATCH] Update [ghstack-poisoned] --- backends/apple/metal/CMakeLists.txt | 1 + backends/apple/metal/metal_backend.py | 10 + backends/apple/metal/ops/gated_delta_rule.py | 84 +++++ .../metal/runtime/ops/op_gated_delta_rule.mm | 323 ++++++++++++++++++ backends/apple/metal/tests/test_modules.py | 42 +++ 5 files changed, 460 insertions(+) create mode 100644 backends/apple/metal/ops/gated_delta_rule.py create mode 100644 backends/apple/metal/runtime/ops/op_gated_delta_rule.mm diff --git a/backends/apple/metal/CMakeLists.txt b/backends/apple/metal/CMakeLists.txt index a00c440363f..84861c1d517 100644 --- a/backends/apple/metal/CMakeLists.txt +++ b/backends/apple/metal/CMakeLists.txt @@ -46,6 +46,7 @@ set(_aoti_metal_sources runtime/ops/op_bmm.mm runtime/ops/op_convolution.mm runtime/ops/op_gather_qmv.mm + runtime/ops/op_gated_delta_rule.mm runtime/ops/op_linear_4bit.mm runtime/ops/op_mm.mm runtime/ops/op_sdpa.mm diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 1da6d603ce1..5c578fce3aa 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -38,6 +38,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "torchao::_linear_fp_act_4bit_weight": None, "at::_ops::topk::call": None, "metal::gather_qmv": None, + "metal::gated_delta_rule": None, } @classmethod @@ -88,6 +89,15 @@ def get_aoti_compile_options( except ImportError: pass + try: + from executorch.backends.apple.metal.ops.gated_delta_rule import ( + metal_gated_delta_rule_c_shim, + ) + + custom_c_shims.update(metal_gated_delta_rule_c_shim) + except ImportError: + pass + inductor_configs["aot_inductor.custom_ops_to_c_shims"] = custom_c_shims return inductor_configs diff --git a/backends/apple/metal/ops/gated_delta_rule.py b/backends/apple/metal/ops/gated_delta_rule.py new file mode 
100644 index 00000000000..7c4c2bbfc08 --- /dev/null +++ b/backends/apple/metal/ops/gated_delta_rule.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +metal::gated_delta_rule custom op for linear attention recurrence. + +Performs the gated delta rule recurrence over T time steps, mutating +the recurrent state in-place. The Metal fallback kernel is in +runtime/ops/op_gated_delta_rule.mm. +""" + +import torch +from torch import Tensor + + +@torch.library.custom_op("metal::gated_delta_rule", mutates_args=("state",)) +def gated_delta_rule( + q: Tensor, # [B, T, Hk, Dk] + k: Tensor, # [B, T, Hk, Dk] + v: Tensor, # [B, T, Hv, Dv] + g: Tensor, # [B, T, Hv] — decay gate (already exp'd) + beta: Tensor, # [B, T, Hv] — update gate + state: Tensor, # [B, Hv, Dv, Dk] — recurrent state (MUTATED) +) -> Tensor: + """Reference implementation: sequential recurrence over T.""" + B, T_len, Hk, Dk = q.shape + Hv, Dv = v.shape[-2:] + + s = state.clone().float() + ys = [] + + for t in range(T_len): + q_t = q[:, t].float() # [B, Hk, Dk] + k_t = k[:, t].float() # [B, Hk, Dk] + v_t = v[:, t].float() # [B, Hv, Dv] + g_t = g[:, t].float() # [B, Hv] + beta_t = beta[:, t].float() # [B, Hv] + + # Decay + s = s * g_t[:, :, None, None] + + # Project state by key + kv_mem = (s * k_t[:, :, None, :]).sum(dim=-1) # [B, Hv, Dv] + + # Delta rule update + delta = (v_t - kv_mem) * beta_t[:, :, None] # [B, Hv, Dv] + s = s + k_t[:, :, None, :] * delta[:, :, :, None] # [B, Hv, Dv, Dk] + + # Read from state + y_t = (s * q_t[:, :, None, :]).sum(dim=-1) # [B, Hv, Dv] + ys.append(y_t) + + state.copy_(s.to(state.dtype)) + return torch.stack(ys, dim=1).to(q.dtype) + + +@torch.library.register_fake("metal::gated_delta_rule") +def gated_delta_rule_fake( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + state: 
Tensor, +) -> Tensor: + B, T = q.shape[:2] + Hv, Dv = v.shape[-2:] + return torch.empty(B, T, Hv, Dv, dtype=q.dtype, device=q.device) + + +# C shim mapping for AOTInductor code generation. +# The op mutates state in-place and returns one tensor (y). AOTInductor's +# auto_functionalized wrapper passes 6 input handles + 1 output pointer. +metal_gated_delta_rule_c_shim = { + torch.ops.metal.gated_delta_rule.default: [ + "AOTITorchError aoti_torch_mps_gated_delta_rule(" + "AtenTensorHandle Q, AtenTensorHandle K, AtenTensorHandle V, " + "AtenTensorHandle G, AtenTensorHandle Beta, AtenTensorHandle StateIn, " + "AtenTensorHandle* retY)" + ], +} diff --git a/backends/apple/metal/runtime/ops/op_gated_delta_rule.mm b/backends/apple/metal/runtime/ops/op_gated_delta_rule.mm new file mode 100644 index 00000000000..f119500f2b0 --- /dev/null +++ b/backends/apple/metal/runtime/ops/op_gated_delta_rule.mm @@ -0,0 +1,323 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Gated delta rule recurrence kernel for linear attention (Qwen 3.5 MoE). +// +// Ported from the MLX delegate PR (#18785) Metal shader. The kernel processes +// the full sequence sequentially within a single GPU dispatch, keeping +// recurrent state in per-thread registers. +// +// Recurrence per time step: +// state *= exp(g_t) -- decay +// kv_mem = sum(state * k_t, dim=-1) -- project state by key +// delta = beta_t * (v_t - kv_mem) -- delta rule update +// state += outer(k_t, delta) -- rank-1 state update +// output_t = sum(state * q_t, dim=-1) -- read from state +// +// Grid: [32, Dv, B*Hv] Threadgroup: [32, 4, 1] +// Each simdgroup of 32 threads handles Dk/32 elements of the key dimension. 
+ +#include + +namespace executorch { +namespace backends { +namespace metal { +namespace { + +static std::string get_gated_delta_rule_metal_source() { + return R"( + #include + #include + using namespace metal; + + // Gated delta rule recurrence kernel. + // Template args: InT=data type, Dk/Dv/Hk/Hv=static dimensions. + // From MLX delegate PR #18785 (Copyright Meta Platforms, Inc.). + template + [[kernel]] void gated_delta_step( + const device InT* q [[buffer(0)]], // [B, T, Hk, Dk] + const device InT* k [[buffer(1)]], // [B, T, Hk, Dk] + const device InT* v [[buffer(2)]], // [B, T, Hv, Dv] + const device InT* g [[buffer(3)]], // [B, T, Hv] + const device InT* beta [[buffer(4)]], // [B, T, Hv] + const device InT* state_in [[buffer(5)]], // [B, Hv, Dv, Dk] + device InT* y [[buffer(6)]], // [B, T, Hv, Dv] + device InT* state_out [[buffer(7)]], // [B, Hv, Dv, Dk] + constant uint& T_val [[buffer(8)]], // sequence length + uint3 thread_position_in_grid [[thread_position_in_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint thread_index_in_simdgroup [[thread_index_in_simdgroup]]) { + + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + int T = static_cast(T_val); + + // q, k: [B, T, Hk, Dk] + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + + // v, y: [B, T, Hv, Dv] + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + // state_in, state_out: [B, Hv, Dv, Dk] + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast(i_state[s_idx]); + } + + // g, beta: [B, T, Hv] + auto 
g_ = g + b_idx * T * Hv; + auto beta_ = beta + b_idx * T * Hv; + + for (int t = 0; t < T; ++t) { + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * static_cast(g_[hv_idx]); + kv_mem += state[i] * static_cast(k_[s_idx]); + } + kv_mem = simd_sum(kv_mem); + + auto delta = (static_cast(v_[dv_idx]) - kv_mem) * static_cast(beta_[hv_idx]); + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + static_cast(k_[s_idx]) * delta; + out += state[i] * static_cast(q_[s_idx]); + } + out = simd_sum(out); + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + } + // Advance to next time step + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + g_ += Hv; + beta_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast(state[i]); + } + } + + // Instantiate for Qwen 3.5 MoE dimensions: Dk=128, Dv=128, Hk=16, Hv=32 + #define INSTANTIATE_GDR(DTYPE, Dk, Dv, Hk, Hv) \ + template [[host_name("gated_delta_step_" #DTYPE \ + "_dk" #Dk "_dv" #Dv "_hk" #Hk "_hv" #Hv)]] \ + [[kernel]] void gated_delta_step( \ + const device DTYPE* q [[buffer(0)]], \ + const device DTYPE* k [[buffer(1)]], \ + const device DTYPE* v [[buffer(2)]], \ + const device DTYPE* g [[buffer(3)]], \ + const device DTYPE* beta [[buffer(4)]], \ + const device DTYPE* state_in [[buffer(5)]], \ + device DTYPE* y [[buffer(6)]], \ + device DTYPE* state_out [[buffer(7)]], \ + constant uint& T_val [[buffer(8)]], \ + uint3 thread_position_in_grid [[thread_position_in_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint thread_index_in_simdgroup [[thread_index_in_simdgroup]]) + + // Qwen 3.5 MoE real model dimensions (Hk=16 after repeat_interleave → 32) + INSTANTIATE_GDR(float, 128, 128, 32, 32); + INSTANTIATE_GDR(bfloat, 128, 128, 32, 32); + // Tiny test model 
dimensions (Hk=2 after repeat_interleave → 4) + INSTANTIATE_GDR(float, 64, 64, 4, 4); + INSTANTIATE_GDR(bfloat, 64, 64, 4, 4); + + )"; +} + +std::unique_ptr gdr_shader_library = nullptr; +std::once_flag gdr_shader_library_once_flag; + +ETMetalShaderLibrary* get_gdr_shader_library() { + std::call_once(gdr_shader_library_once_flag, []() { + std::string source = get_gated_delta_rule_metal_source(); + gdr_shader_library = std::make_unique(source); + }); + return gdr_shader_library.get(); +} + +} // namespace + + +extern "C" { + +AOTITorchError aoti_torch_mps_gated_delta_rule( + AOTITensorHandle Q, + AOTITensorHandle K, + AOTITensorHandle V, + AOTITensorHandle G, + AOTITensorHandle Beta, + AOTITensorHandle StateIn, + AOTITensorHandle* retY) { + + ET_LOG(Debug, "aoti_torch_mps_gated_delta_rule: Starting"); + + if (!Q || !K || !V || !G || !Beta || !StateIn || !retY) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: null required tensor handles"); + return Error::InvalidArgument; + } + + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: Failed to get Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + auto* q_tensor = reinterpret_cast(Q); // [B, T, Hk, Dk] + auto* k_tensor = reinterpret_cast(K); // [B, T, Hk, Dk] + auto* v_tensor = reinterpret_cast(V); // [B, T, Hv, Dv] + auto* g_tensor = reinterpret_cast(G); // [B, T, Hv] + auto* beta_tensor = reinterpret_cast(Beta); // [B, T, Hv] + auto* state_tensor = reinterpret_cast(StateIn); // [B, Hv, Dv, Dk] + + if (q_tensor->dim() != 4 || v_tensor->dim() != 4 || state_tensor->dim() != 4) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: q/v must be 4D, state must be 4D"); + return Error::InvalidArgument; + } + + int32_t B = static_cast(q_tensor->sizes()[0]); + int32_t T = static_cast(q_tensor->sizes()[1]); + int32_t Hk = static_cast(q_tensor->sizes()[2]); + int32_t Dk = static_cast(q_tensor->sizes()[3]); + int32_t Hv = 
static_cast(v_tensor->sizes()[2]); + int32_t Dv = static_cast(v_tensor->sizes()[3]); + + ET_LOG(Debug, "aoti_torch_mps_gated_delta_rule: B=%d, T=%d, Hk=%d, Dk=%d, Hv=%d, Dv=%d", + B, T, Hk, Dk, Hv, Dv); + + if (Dk % 32 != 0) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: Dk=%d must be multiple of 32", Dk); + return Error::InvalidArgument; + } + + // Determine dtype + int32_t dtype = static_cast(q_tensor->scalar_type()); + size_t element_size; + std::string type_str; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + element_size = sizeof(float); + type_str = "float"; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + element_size = sizeof(uint16_t); + type_str = "bfloat"; + } else { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: Unsupported dtype %d", dtype); + return Error::InvalidArgument; + } + + ETMetalShaderLibrary* library = get_gdr_shader_library(); + if (!library) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: Failed to get shader library"); + return Error::Internal; + } + + std::string kernel_name = "gated_delta_step_" + type_str + + "_dk" + std::to_string(Dk) + "_dv" + std::to_string(Dv) + + "_hk" + std::to_string(Hk) + "_hv" + std::to_string(Hv); + ET_LOG(Debug, "aoti_torch_mps_gated_delta_rule: Using kernel: %s", kernel_name.c_str()); + + auto kernel_func = library->getKernelFunction(kernel_name); + if (!kernel_func) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule: Kernel not found: %s", kernel_name.c_str()); + return Error::Internal; + } + + // Allocate output y [B, T, Hv, Dv] + size_t y_bytes = B * T * Hv * Dv * element_size; + void* y_ptr = nullptr; + allocate_mtl_buffer(&y_ptr, y_bytes); + + std::vector y_sizes = {B, T, Hv, Dv}; + std::vector y_strides = {T * Hv * Dv, Hv * Dv, Dv, 1}; + + AOTITensorHandle y_handle = nullptr; + aoti_torch_create_tensor_from_blob_v2( + y_ptr, 4, y_sizes.data(), y_strides.data(), + 0, dtype, 13, 0, &y_handle, 0, nullptr, 0); + + if (!y_handle) { + aoti_torch_mps_free(y_ptr); + 
return Error::Internal; + } + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[y_ptr] = 1; + + auto* y_tensor = reinterpret_cast(y_handle); + + // State is mutated in-place: kernel writes to state_tensor directly + // (state_out = state_in in the kernel args) + uint T_uint = static_cast(T); + + // Execute kernel + kernel_func->runCommandBlock([&]() { + kernel_func->startEncoding(); + + kernel_func->setArg(0, *q_tensor); + kernel_func->setArg(1, *k_tensor); + kernel_func->setArg(2, *v_tensor); + kernel_func->setArg(3, *g_tensor); + kernel_func->setArg(4, *beta_tensor); + kernel_func->setArg(5, *state_tensor); // state_in + kernel_func->setArg(6, *y_tensor); + kernel_func->setArg(7, *state_tensor); // state_out = state_in (in-place) + kernel_func->setArg(8, T_uint); + + // Grid: [32, Dv, B*Hv] Threadgroup: [32, 4, 1] + kernel_func->dispatchThreadgroups( + 1, // gridX (32 threads in threadgroup.x) + Dv, // gridY: one per value dim + B * Hv, // gridZ: one per (batch, head) + 32, // threadsX: simdgroup size + 4, // threadsY + 1); // threadsZ + }); + + *retY = y_handle; + + ET_LOG(Debug, "aoti_torch_mps_gated_delta_rule: Completed successfully"); + + } // @autoreleasepool + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_gated_delta_rule exception: %s", e.what()); + return Error::Internal; + } catch (...) 
class GatedDeltaRule(nn.Module):
    """Wrapper around metal::gated_delta_rule for testing the linear
    attention recurrence kernel.

    The module owns the recurrent state as a zero-initialized buffer named
    ``state`` of shape ``[batch_size, num_v_heads, v_dim, k_dim]`` and passes
    it to the op on every forward call.

    Args:
        batch_size: Leading (B) dimension of the state buffer.
        num_v_heads: Number of value heads (Hv).
        v_dim: Per-head value dimension (Dv).
        k_dim: Per-head key dimension (Dk).
    """

    def __init__(
        self,
        batch_size: int = 1,
        num_v_heads: int = 4,
        v_dim: int = 64,
        k_dim: int = 64,
    ) -> None:
        super().__init__()
        self.register_buffer(
            "state", torch.zeros(batch_size, num_v_heads, v_dim, k_dim)
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        """Call ``torch.ops.metal.gated_delta_rule`` with the module's state.

        The op module is imported here for its registration side effect so
        that ``torch.ops.metal.gated_delta_rule`` resolves.
        """
        import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401

        return torch.ops.metal.gated_delta_rule(q, k, v, g, beta, self.state)