Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 142 additions & 8 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,35 @@ def _prepare_and_quantize_mlx(model, config, args):
pack_all_switch_linears(model)


def load_and_quantize(args):
def _prepare_and_quantize_metal(model, config, args):
    """Metal: apply source transforms, quantize experts + non-expert layers.

    Registers the Metal custom ops (import side effect), optionally
    quantizes expert and non-expert linear weights when ``args.qlinear``
    is set, then materializes buffers and applies the Metal-specific
    source transformations to ``model`` in place.

    Args:
        model: The loaded Qwen3.5 MoE model, mutated in place.
        config: Model config (passed through to the helpers).
        args: Parsed CLI args; reads ``qlinear`` and ``qlinear_group_size``.
    """
    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
        metal_source_transformations,
        quantize_experts_metal,
    )

    # Single guard instead of two back-to-back ``if args.qlinear`` checks:
    # expert quantization and non-expert quantization always run together.
    if args.qlinear:
        # Quantize expert weights to Metal-compatible INT4 format.
        quantize_experts_metal(model, config, args.qlinear_group_size)

        from executorch.extension.llm.export.quantize import quantize_model_

        # skip_incompatible_shapes skips shared_expert_gate (N=1, N%4!=0)
        quantize_model_(
            model,
            qlinear_config=args.qlinear,
            qlinear_group_size=args.qlinear_group_size,
            skip_incompatible_shapes=True,
        )

    _materialize_buffers(model, config)
    metal_source_transformations(model, config=config)


def load_and_quantize(args): # noqa: C901
"""Load model from checkpoint, optionally quantize.

For CUDA: quantizes experts with packed INT4, then transformer layers on CUDA.
Expand Down Expand Up @@ -146,6 +174,11 @@ def load_and_quantize(args):
)
_prepare_and_quantize_mlx(model, config, args)

elif backend == "metal":
if args.prequantized:
raise ValueError("Metal backend does not support --prequantized.")
_prepare_and_quantize_metal(model, config, args)

elif backend == "cuda":
if args.prequantized:
return load_prequantized_model(args.prequantized, args.max_seq_len)
Expand Down Expand Up @@ -497,6 +530,8 @@ def export_and_lower(model, config, args):

if backend == "mlx":
_export_mlx(model, config, args)
elif backend == "metal":
_export_metal(model, config, args)
else:
_export_cuda(model, config, args)

Expand Down Expand Up @@ -581,6 +616,100 @@ def _export_mlx(model, config, args):
print("Done!")


def _export_metal(model, config, args):
    """Export model to .pte via torch.export + Metal backend."""
    import torch._inductor.config as inductor_config

    from executorch.backends.apple.metal.metal_backend import MetalBackend
    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
    from executorch.exir import (
        EdgeCompileConfig,
        ExecutorchBackendConfig,
        to_edge_transform_and_lower,
    )
    from executorch.exir.passes import MemoryPlanningPass
    from torch.export import Dim, export

    inductor_config.coordinate_descent_tuning = False
    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"

    # Decode handles exactly one token per step, so its shapes stay static.
    print("Exporting decode method...")
    with torch.no_grad():
        exported_decode = export(
            model,
            (
                torch.tensor([[0]], dtype=torch.long),
                torch.tensor([0], dtype=torch.long),
            ),
            strict=True,
        )
    print("Decode export successful!")

    # Prefill consumes the whole prompt; its sequence length is dynamic.
    print("Exporting prefill method...")
    token_seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
    with torch.no_grad():
        exported_prefill = export(
            model,
            (
                torch.tensor([[0, 1]], dtype=torch.long),
                torch.tensor([0, 1], dtype=torch.long),
            ),
            dynamic_shapes=({1: token_seq_dim}, {0: token_seq_dim}),
            strict=True,
        )
    print("Prefill export successful!")

    print("Lowering to ExecuTorch with Metal...")
    constant_methods = {
        "get_max_seq_len": config.max_seq_len,
        "get_vocab_size": config.vocab_size,
        "get_n_layers": config.num_hidden_layers,
        "use_kv_cache": True,
        "use_sdpa_with_kv_cache": False,
        "enable_dynamic_shape": True,
    }
    # One partitioner per exported method, tagged with that method's name.
    partitioners = {
        method: [
            MetalPartitioner(
                [MetalBackend.generate_method_name_compile_spec(method)]
            )
        ]
        for method in ("decode", "prefill")
    }
    edge_prog = to_edge_transform_and_lower(
        {"decode": exported_decode, "prefill": exported_prefill},
        partitioner=partitioners,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_dim_order=True,
        ),
        constant_methods=constant_methods,
    )
    executorch_prog = edge_prog.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            do_quant_fusion_and_const_prop=True,
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )

    # Write the program (and any external tensor data) under output_dir.
    os.makedirs(args.output_dir, exist_ok=True)
    pte_path = os.path.join(args.output_dir, "model.pte")
    print(f"Saving to {pte_path}...")
    with open(pte_path, "wb") as out_file:
        executorch_prog.write_to_file(out_file)
    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
    print(f"Saved {size_mb:.1f} MB")

    if executorch_prog._tensor_data:
        executorch_prog.write_tensor_data_to_file(args.output_dir)
        print(f"Saved tensor data to {args.output_dir}/")

    print("Done!")


def _export_cuda(model, config, args):
"""Export model to .pte via torch.export + CUDA backend.

Expand Down Expand Up @@ -708,10 +837,8 @@ def _export_cuda(model, config, args):
# ---------------------------------------------------------------------------


def main():
parser = argparse.ArgumentParser(
description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
)
def main(): # noqa: C901
parser = argparse.ArgumentParser(description="Export Qwen3.5 MoE to ExecuTorch")
parser.add_argument(
"--model-dir",
default=None,
Expand All @@ -729,13 +856,13 @@ def main():
parser.add_argument(
"--backend",
default="cuda",
choices=["cuda", "mlx"],
help="Backend for export: cuda (default) or mlx.",
choices=["cuda", "mlx", "metal"],
help="Backend for export: cuda (default), mlx, or metal.",
)
parser.add_argument(
"--qlinear",
default=None,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantize linear layers.",
)
parser.add_argument(
Expand Down Expand Up @@ -805,6 +932,13 @@ def main():
if args.turboquant:
parser.error("--turboquant is not supported with --backend mlx")

if args.backend == "metal":
if args.turboquant:
parser.error("--turboquant is not supported with --backend metal")

if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")

model, config = load_and_quantize(args)

if args.backend == "cuda":
Expand Down
3 changes: 3 additions & 0 deletions extension/llm/export/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_s
shape = m.weight.shape
if config_name == "nvfp4":
compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0
elif config_name == "fpa4w":
# MPS UIntx kernel requires N % 4 == 0 when M > 1 (e.g. prefill)
compatible = shape[-1] % group_size == 0 and shape[-2] % 4 == 0
elif group_size != 0:
compatible = shape[-1] % group_size == 0
else:
Expand Down
Loading