diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py
index 398df1bb086..1a5c4dd9edf 100644
--- a/examples/models/qwen3_5_moe/export.py
+++ b/examples/models/qwen3_5_moe/export.py
@@ -68,7 +68,35 @@ def _prepare_and_quantize_mlx(model, config, args):
     pack_all_switch_linears(model)
 
 
-def load_and_quantize(args):
+def _prepare_and_quantize_metal(model, config, args):
+    """Metal: quantize if requested, materialize buffers, apply source transforms."""
+    # Unused names: kept for import-time effects (presumably op registration).
+    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
+    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
+    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
+        metal_source_transformations,
+        quantize_experts_metal,
+    )
+
+    if args.qlinear:
+        # Experts first (Metal-compatible INT4), then the remaining linear layers.
+        # NOTE(review): the INT4 expert path runs for any --qlinear value -- confirm.
+        quantize_experts_metal(model, config, args.qlinear_group_size)
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        # skip_incompatible_shapes skips shared_expert_gate (N=1, N%4!=0)
+        quantize_model_(
+            model,
+            qlinear_config=args.qlinear,
+            qlinear_group_size=args.qlinear_group_size,
+            skip_incompatible_shapes=True,
+        )
+
+    _materialize_buffers(model, config)
+    metal_source_transformations(model, config=config)
+
+
+def load_and_quantize(args):  # noqa: C901
     """Load model from checkpoint, optionally quantize.
 
     For CUDA: quantizes experts with packed INT4, then transformer layers on CUDA.
@@ -146,6 +174,11 @@ def load_and_quantize(args):
             )
         _prepare_and_quantize_mlx(model, config, args)
 
+    elif backend == "metal":
+        if args.prequantized:
+            raise ValueError("Metal backend does not support --prequantized.")
+        _prepare_and_quantize_metal(model, config, args)
+
     elif backend == "cuda":
         if args.prequantized:
             return load_prequantized_model(args.prequantized, args.max_seq_len)
@@ -497,6 +530,8 @@ def export_and_lower(model, config, args):
 
     if backend == "mlx":
         _export_mlx(model, config, args)
+    elif backend == "metal":
+        _export_metal(model, config, args)
     else:
         _export_cuda(model, config, args)
 
@@ -581,6 +616,100 @@ def _export_mlx(model, config, args):
     print("Done!")
 
 
+def _export_metal(model, config, args):
+    """Export decode and prefill methods and lower them to a Metal .pte.
+
+    ``decode`` is traced with a single token (static shape); ``prefill`` with a
+    dynamic sequence length from 2 to ``config.max_seq_len - 1``. The program is
+    written to ``model.pte`` in ``args.output_dir``, along with its tensor data
+    when ``et_program._tensor_data`` is non-empty.
+    """
+    import torch._inductor.config as inductor_config
+
+    from executorch.backends.apple.metal.metal_backend import MetalBackend
+    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
+    from executorch.exir import (
+        EdgeCompileConfig,
+        ExecutorchBackendConfig,
+        to_edge_transform_and_lower,
+    )
+    from executorch.exir.passes import MemoryPlanningPass
+    from torch.export import Dim, export
+
+    inductor_config.coordinate_descent_tuning = False
+    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
+
+    # --- Decode method (T=1, static shape) ---
+    print("Exporting decode method...")
+    decode_inputs = (
+        torch.tensor([[0]], dtype=torch.long),
+        torch.tensor([0], dtype=torch.long),
+    )
+    with torch.no_grad():
+        decode_ep = export(model, decode_inputs, strict=True)
+    print("Decode export successful!")
+
+    # --- Prefill method (T>=2, dynamic shape) ---
+    print("Exporting prefill method...")
+    prefill_inputs = (
+        torch.tensor([[0, 1]], dtype=torch.long),
+        torch.tensor([0, 1], dtype=torch.long),
+    )
+    seq_len = Dim("seq_len", min=2, max=config.max_seq_len - 1)
+    # Token dim 1 and position dim 0 share the same symbolic length.
+    shapes = ({1: seq_len}, {0: seq_len})
+    with torch.no_grad():
+        prefill_ep = export(model, prefill_inputs, dynamic_shapes=shapes, strict=True)
+    print("Prefill export successful!")
+
+    # Lower with Metal backend
+    print("Lowering to ExecuTorch with Metal...")
+    # Constant methods baked into the program (passed as constant_methods below).
+    metadata = {
+        "get_max_seq_len": config.max_seq_len,
+        "get_vocab_size": config.vocab_size,
+        "get_n_layers": config.num_hidden_layers,
+        "use_kv_cache": True,
+        "use_sdpa_with_kv_cache": False,
+        "enable_dynamic_shape": True,
+    }
+    programs = {"decode": decode_ep, "prefill": prefill_ep}
+    # One partitioner per method, each given that method's name compile spec.
+    spec_for = MetalBackend.generate_method_name_compile_spec
+    partitioners = {m: [MetalPartitioner([spec_for(m)])] for m in programs}
+    edge_manager = to_edge_transform_and_lower(
+        programs,
+        partitioner=partitioners,
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+            _skip_dim_order=True,
+        ),
+        constant_methods=metadata,
+    )
+    et_program = edge_manager.to_executorch(
+        config=ExecutorchBackendConfig(
+            extract_delegate_segments=True,
+            do_quant_fusion_and_const_prop=True,
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+        ),
+    )
+
+    # Save .pte
+    os.makedirs(args.output_dir, exist_ok=True)
+    pte_path = os.path.join(args.output_dir, "model.pte")
+    print(f"Saving to {pte_path}...")
+    with open(pte_path, "wb") as f:
+        et_program.write_to_file(f)
+    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
+    print(f"Saved {size_mb:.1f} MB")
+
+    if et_program._tensor_data:
+        et_program.write_tensor_data_to_file(args.output_dir)
+        print(f"Saved tensor data to {args.output_dir}/")
+
+    print("Done!")
+
+
 def _export_cuda(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
 
@@ -708,10 +837,8 @@ def _export_cuda(model, config, args): # --------------------------------------------------------------------------- -def main(): - parser = argparse.ArgumentParser( - description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)" - ) +def main(): # noqa: C901 + parser = argparse.ArgumentParser(description="Export Qwen3.5 MoE to ExecuTorch") parser.add_argument( "--model-dir", default=None, @@ -729,13 +856,13 @@ def main(): parser.add_argument( "--backend", default="cuda", - choices=["cuda", "mlx"], - help="Backend for export: cuda (default) or mlx.", + choices=["cuda", "mlx", "metal"], + help="Backend for export: cuda (default), mlx, or metal.", ) parser.add_argument( "--qlinear", default=None, - choices=["4w", "8w", "8da4w", "8da8w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], help="Quantize linear layers.", ) parser.add_argument( @@ -805,6 +932,13 @@ def main(): if args.turboquant: parser.error("--turboquant is not supported with --backend mlx") + if args.backend == "metal": + if args.turboquant: + parser.error("--turboquant is not supported with --backend metal") + + if args.qlinear == "fpa4w" and args.backend != "metal": + parser.error("--qlinear=fpa4w can only be used with --backend=metal") + model, config = load_and_quantize(args) if args.backend == "cuda": diff --git a/extension/llm/export/quantize.py b/extension/llm/export/quantize.py index fb2678ff60f..90590f21cb9 100644 --- a/extension/llm/export/quantize.py +++ b/extension/llm/export/quantize.py @@ -111,6 +111,9 @@ def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_s shape = m.weight.shape if config_name == "nvfp4": compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0 + elif config_name == "fpa4w": + # MPS UIntx kernel requires N % 4 == 0 when M > 1 (e.g. prefill) + compatible = shape[-1] % group_size == 0 and shape[-2] % 4 == 0 elif group_size != 0: compatible = shape[-1] % group_size == 0 else: