Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
af6810b
up
metascroy Mar 3, 2026
0f03f2b
up
metascroy Mar 3, 2026
bf36732
up
metascroy Mar 3, 2026
6a2d455
up
metascroy Mar 3, 2026
493d9ea
up
metascroy Mar 5, 2026
5ee8ac4
up
metascroy Mar 5, 2026
0df21d9
up
metascroy Mar 5, 2026
93afd3e
up
metascroy Mar 5, 2026
0adbe8c
up
metascroy Mar 5, 2026
f0b8e71
up
metascroy Mar 3, 2026
5493ea1
up
metascroy Mar 3, 2026
51462b3
up
metascroy Mar 3, 2026
1add04d
up
metascroy Mar 5, 2026
d15ee3c
up
metascroy Mar 5, 2026
b7da263
up
metascroy Mar 5, 2026
681ae8c
up
metascroy Mar 3, 2026
848cbd3
up
metascroy Mar 4, 2026
1ae519b
up
metascroy Mar 5, 2026
b61faab
up
metascroy Mar 5, 2026
eb22885
up
metascroy Mar 5, 2026
b8f0fa6
up
metascroy Mar 5, 2026
1ed89ee
up
metascroy Mar 3, 2026
cbb43c7
up
metascroy Mar 3, 2026
7e5abd4
up
metascroy Mar 3, 2026
c534395
up
metascroy Mar 3, 2026
da81e86
up
metascroy Mar 5, 2026
74a3f1b
up
metascroy Mar 5, 2026
0b8b0af
up
metascroy Mar 5, 2026
afa912e
up
metascroy Mar 5, 2026
6e924fe
up
metascroy Mar 5, 2026
6f805bf
Merge branch 'mlx-delegate-part1' into mlx-delegate-part2
metascroy Mar 27, 2026
6a699b1
Merge branch 'mlx-delegate-part2' into mlx-delegate-part3
metascroy Mar 27, 2026
ea26815
up
metascroy Apr 7, 2026
a7d9435
up
metascroy Apr 7, 2026
fd01e12
Merge branch 'main' into mlx-delegate-part3
metascroy Apr 7, 2026
8f047dd
up
metascroy Apr 7, 2026
3cbfa30
Merge branch 'main' into mlx-delegate-part3
metascroy Apr 7, 2026
984de05
up
metascroy Apr 7, 2026
417963a
Merge branch 'mlx-delegate-part3' into qwen-moe-part4
metascroy Apr 7, 2026
6e5f323
up
metascroy Apr 7, 2026
0d3a549
up
metascroy Apr 7, 2026
4d833fe
Merge branch 'mlx-delegate-part3' into qwen-moe-part4
metascroy Apr 8, 2026
46e8bbc
Merge branch 'main' into qwen-moe-part4
metascroy Apr 8, 2026
89b499b
Merge branch 'main' into qwen-moe-part4
metascroy Apr 8, 2026
2ca57a7
up
metascroy Apr 8, 2026
92c541a
up
metascroy Apr 9, 2026
47e5f99
up
metascroy Apr 9, 2026
2bf242c
up
metascroy Apr 9, 2026
2b8e675
Merge branch 'main' into qwen-moe-part4
metascroy Apr 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
- extension/audio/**
- examples/models/parakeet/**
- examples/models/voxtral_realtime/**
- examples/models/qwen3_5_moe/**
workflow_dispatch:

permissions: {}
Expand Down Expand Up @@ -63,6 +64,61 @@ jobs:
./cmake-out/backends/mlx/test/multi_thread_test_runner
echo "::endgroup::"

echo "::group::Run gated_delta_rule op tests"
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
echo "::endgroup::"

test-mlx-qwen35-moe:
  # Exports a tiny Qwen 3.5 MoE model with the MLX backend, checks that the
  # number of AsType nodes in the lowered program stays bounded, then runs
  # inference and compares the generated token ids with a known-good sequence.
  uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
  with:
    job-name: test-mlx-qwen35-moe
    runner: macos-14-xlarge
    python-version: "3.12"
    submodules: recursive
    ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
    timeout: 90
    script: |
      set -eux

      echo "::group::Install ExecuTorch"
      ${CONDA_RUN} python install_executorch.py > /dev/null
      echo "::endgroup::"

      ${CONDA_RUN} pip list

      echo "::group::Export Qwen 3.5 MoE (tiny model)"
      ${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.export \
        --tiny-test \
        --backend mlx \
        --qlinear 4w \
        --qlinear-group-size 32 \
        --output-dir /tmp/qwen35_moe_mlx_tiny
      echo "::endgroup::"

      echo "::group::Check AsType node count"
      # Run the inspector on its own, not inside a pipeline: without pipefail
      # a crashed inspector would otherwise be reported as "0 AsType nodes"
      # by the grep fallback and the check would pass silently.
      INSPECT_LOG=/tmp/qwen35_moe_mlx_tiny/pte_inspector.log
      if ! ${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \
          /tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions > "${INSPECT_LOG}" 2>&1; then
        cat "${INSPECT_LOG}"
        echo "Failed: pte_inspector exited with an error"
        exit 1
      fi
      # grep -c exits 1 when there are zero matches; that is a valid result.
      ASTYPE_COUNT=$(grep -c "AsTypeNode" "${INSPECT_LOG}" || true)
      echo "AsType nodes: ${ASTYPE_COUNT}"
      if [ "$ASTYPE_COUNT" -gt 23 ]; then
        echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}"
        exit 1
      fi
      echo "::endgroup::"

      echo "::group::Run Qwen 3.5 MoE inference"
      # Capture the output but still print it when the run itself fails;
      # a bare assignment would abort under `set -e` before the echo.
      OUTPUT=$(${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.run \
        --pte /tmp/qwen35_moe_mlx_tiny/model.pte \
        --prompt-len 4 \
        --max-new-tokens 5 2>&1) || {
        echo "$OUTPUT"
        echo "Failed: Qwen 3.5 MoE inference exited with an error"
        exit 1
      }
      echo "$OUTPUT"
      if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 167, 81, 167, 81\]"; then
        echo "Success: Qwen 3.5 MoE MLX export + inference completed with expected output"
      else
        echo "Failed: unexpected output (expected [167, 167, 81, 167, 81])"
        exit 1
      fi
      echo "::endgroup::"

backend-tester:
strategy:
fail-fast: false
Expand Down
15 changes: 15 additions & 0 deletions backends/mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx)
# Op logging option (for debugging) - OFF by default for performance
option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF)

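# Custom kernel execution - ON by default (see the option value below). When
# enabled, MetalKernelNode can execute arbitrary Metal shader code embedded in
# .pte files. Set this to OFF unless every .pte you load is from a trusted
# source.
# NOTE(review): a security-sensitive switch defaulting to ON is unusual -
# confirm that this default is intentional.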
option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
"Allow MetalKernelNode to execute custom Metal shaders from .pte files"
ON
)

set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
)
Expand All @@ -262,6 +270,13 @@ if(ET_MLX_ENABLE_OP_LOGGING)
message(STATUS "MLX delegate op logging ENABLED")
endif()

# When the option is on, define ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION for the
# mlxdelegate target (PRIVATE: the definition applies to this target only) and
# report the setting at configure time.
if(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION)
target_compile_definitions(
mlxdelegate PRIVATE ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
)
message(STATUS "MLX delegate custom kernel execution ENABLED")
endif()

target_include_directories(
mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime
)
Expand Down
89 changes: 76 additions & 13 deletions backends/mlx/builder/program_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import traceback
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -172,6 +173,24 @@ def emit_init(self, op: OpNodeUnion) -> None:
self._chains.append([])
self._chains[self.init_chain_idx].append(Instruction(op=op))

@contextmanager
def new_chain(self):
    """Open a fresh instruction chain and make it current for a ``with`` block.

    An empty chain is appended to ``self._chains`` and ``self._current_chain``
    is pointed at it. When the block exits, normally or through an exception,
    the previously current chain index is restored.

    Usage:
        with P.new_chain() as chain_idx:
            P.emit(MulNode(...))  # goes to the new chain
        # P.emit() goes back to the previous chain

    Yields:
        The index of the newly created chain within ``self._chains``.
    """
    saved_chain = self._current_chain
    self._chains.append([])
    # The chain just appended is the last element.
    new_idx = len(self._chains) - 1
    self._current_chain = new_idx
    try:
        yield new_idx
    finally:
        self._current_chain = saved_chain

def args(self, node: Node) -> Tuple[Any, ...]:
    """Return ``node.args`` translated through ``self.slot_map``."""
    positional = node.args
    return self.slot_map(positional)

Expand Down Expand Up @@ -629,9 +648,12 @@ def _verify_build(self):
info.handler in (noop_handler, PatternHandler.deferred_handler)
or n.users == {}
):
assert (
self.slot_manager.get_slot(n) is None
), f"Did not expect node {n} handled by {info.handler} to have a slot"
# Deferred body nodes may or may not have slots — this is fine.
# Pattern handlers absorb nodes into their body and may set
# slots on them (e.g., GatedDeltaRuleHandler sets getitem[0]'s
# slot to the ScanNode output). Dead nodes (no users) also
# skip the slot check.
pass
else:
assert (
self.slot_manager.get_slot(n) is not None
Expand Down Expand Up @@ -962,6 +984,11 @@ def get_named_data_store(self) -> NamedDataStore:
``ep.constants`` / ``extra_constants`` (which all use unprefixed
keys). The prefix is applied at the exit boundary — the
``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``.

To reduce peak memory, each constant is deleted from the EP
immediately after its bytes are added to the NamedDataStore.
This avoids holding two full copies of all constants simultaneously
(important for large models where constants can be 20+ GB).
"""
named_data_store = NamedDataStore()

Expand All @@ -971,6 +998,17 @@ def get_named_data_store(self) -> NamedDataStore:
key=lambda x: self._slot_to_final_tid.get(x[1], 0),
)

# Free EP constants not used by the MLX graph to reduce peak memory.
used = set(self._constant_name_to_slot.keys())
for ispec in self.ep.graph_signature.input_specs:
if ispec.arg.name in used and ispec.target is not None:
used.add(ispec.target)

for d in (self.ep._state_dict, self.ep._constants):
for name in list(d.keys()):
if name not in used and isinstance(d[name], torch.Tensor):
del d[name]

logger.debug(f"Adding {len(entries)} constants to NamedDataStore...")
for canonical_name, _slot in entries:
tensor = self._find_constant_tensor(canonical_name)
Expand All @@ -983,6 +1021,15 @@ def get_named_data_store(self) -> NamedDataStore:
data=t,
alignment=16,
)

# Free the original tensor from the EP immediately.
# The contiguous copy is now serialized as bytes in the
# NamedDataStore — the EP reference is no longer needed.
# (It would be deleted by lowered_backend_module.py after
# preprocess() returns anyway.)
self._delete_constant_tensor(canonical_name)
del tensor, t

logger.debug("Done adding constants to NamedDataStore")

return named_data_store
Expand Down Expand Up @@ -1011,17 +1058,33 @@ def get_mutable_buffer_names(self) -> List[str]:

def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]:
"""Find a constant tensor by name from various sources."""
if name in self.ep.state_dict:
return self.ep.state_dict[name]
if name in self.ep.constants:
return self.ep.constants[name]
result = self._resolve_constant(name)
if result is None:
return None

d, k = result
return d[k]

def _delete_constant_tensor(self, name: str) -> None:
"""Delete a constant from the EP to free memory during serialization."""

result = self._resolve_constant(name)
if result:
d, k = result
del d[k]

def _resolve_constant(self, name):
"""Returns (dict, key) or None."""
if name in self.ep._state_dict:
return self.ep._state_dict, name
if name in self.ep._constants:
return self.ep._constants, name
if name in self.extra_constants:
return self.extra_constants[name]
# Look up by target
return self.extra_constants, name
for ispec in self.ep.graph_signature.input_specs:
if ispec.arg.name == name and ispec.target is not None:
if ispec.target in self.ep.state_dict:
return self.ep.state_dict[ispec.target]
if ispec.target in self.ep.constants:
return self.ep.constants[ispec.target]
if ispec.target in self.ep._state_dict:
return self.ep._state_dict, ispec.target
if ispec.target in self.ep._constants:
return self.ep._constants, ispec.target
return None
16 changes: 15 additions & 1 deletion backends/mlx/builder/slot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,26 @@ class IdSpace(Enum):
Temp = auto()


# A single decorator only: ``eq=False`` stops dataclass from generating a
# field-based ``__eq__``. Stacking a second ``@dataclass(frozen=True)`` on top
# fails at class creation, because the frozen machinery refuses to overwrite
# the ``__setattr__`` installed by the first pass.
@dataclass(eq=False, frozen=True)
class Slot:
    """Represents an allocated tensor or symbolic int slot.

    Uses identity-based equality and hashing (not field-based) so that
    two Slots with the same (id_type, id_space, idx) — which can happen
    when the delete-as-you-go allocator recycles an idx — remain distinct
    in sets and dicts during build().

    Instances are frozen: fields cannot be reassigned after construction.
    """

    id_type: IdType
    id_space: IdSpace
    idx: Optional[int] = None

    def __eq__(self, other):
        # Identity comparison: equal only to itself.
        return self is other

    def __hash__(self):
        # Consistent with __eq__: one hash per object, not per field tuple.
        return id(self)


class IdManager:
def __init__(self):
Expand Down
114 changes: 114 additions & 0 deletions backends/mlx/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,117 @@ def rope_fake(
) -> Tensor:
"""Fake implementation for tracing."""
return x.new_empty(x.shape)


@torch.library.custom_op("mlx::gather_mm", mutates_args=())
def gather_mm(
a: Tensor, # [..., M, K]
b: Tensor, # [E, K, N] or [..., K, N]
rhs_indices: Optional[Tensor] = None, # Expert selection indices
lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices
sorted_indices: bool = False,
) -> Tensor:
"""
Gather matrix multiply — matches mlx::core::gather_mm semantics exactly.

Output shape = broadcast(lhs_indices, rhs_indices).shape + [M, N]
where M = a.shape[-2], N = b.shape[-1].

For MoE: a=[N_tokens, 1, K], b=[E, K, out], rhs_indices=[N_tokens]
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.
"""
if rhs_indices is not None:
b_sel = b[rhs_indices]
else:
b_sel = b
return torch.matmul(a, b_sel)


@torch.library.register_fake("mlx::gather_mm")
def gather_mm_fake(
a: Tensor,
b: Tensor,
rhs_indices: Optional[Tensor] = None,
lhs_indices: Optional[Tensor] = None,
sorted_indices: bool = False,
) -> Tensor:
# Matches MLX: output = indices.shape + [M, N]
# For simplicity, use matmul shape rules after gather
M = a.shape[-2]
N = b.shape[-1]
if rhs_indices is not None:
batch = rhs_indices.shape
else:
batch = b.shape[:-2]
return a.new_empty((*batch, M, N))


@torch.library.custom_op("mlx::gather_qmm", mutates_args=())
def gather_qmm(
x: Tensor, # [..., M, K]
w: Tensor, # [E, out, in_packed]
scales: Tensor, # [E, out, in//gs]
biases: Optional[Tensor] = None, # [E, out, in//gs] (affine mode)
rhs_indices: Optional[Tensor] = None, # Expert selection indices
lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices
transpose: bool = True,
group_size: int = 32,
bits: int = 4,
mode: str = "affine",
sorted_indices: bool = False,
) -> Tensor:
"""
Gather quantized matrix multiply — matches mlx::core::gather_qmm semantics.

Output shape = broadcast(lhs_indices, rhs_indices).shape + [M, N]

For MoE: x=[N_tokens, 1, K], w=[E, out, K_packed], rhs_indices=[N_tokens]
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.
"""
# Eager fallback: gather, dequantize, matmul
if rhs_indices is not None:
w_sel = w[rhs_indices]
s_sel = scales[rhs_indices]
b_sel = biases[rhs_indices] if biases is not None else None
else:
w_sel = w
s_sel = scales
b_sel = biases

# Dequantize
w_float = w_sel.to(x.dtype)
s_expanded = s_sel.repeat_interleave(group_size, dim=-1)
if b_sel is not None:
b_expanded = b_sel.repeat_interleave(group_size, dim=-1)
w_dequant = w_float * s_expanded + b_expanded
else:
w_dequant = w_float * s_expanded

if transpose:
w_dequant = w_dequant.transpose(-1, -2)

return torch.matmul(x, w_dequant)


@torch.library.register_fake("mlx::gather_qmm")
def gather_qmm_fake(
x: Tensor,
w: Tensor,
scales: Tensor,
biases: Optional[Tensor] = None,
rhs_indices: Optional[Tensor] = None,
lhs_indices: Optional[Tensor] = None,
transpose: bool = True,
group_size: int = 32,
bits: int = 4,
mode: str = "affine",
sorted_indices: bool = False,
) -> Tensor:
# Matches MLX: output = indices.shape + [M, N]
M = x.shape[-2]
N = w.shape[-2] if transpose else w.shape[-1]
if rhs_indices is not None:
batch = rhs_indices.shape
else:
batch = w.shape[:-2]
return x.new_empty((*batch, M, N))
Loading
Loading