diff --git a/models/speaker-diarization/sortformer-streaming/conversion/build_all_variants.py b/models/speaker-diarization/sortformer-streaming/conversion/build_all_variants.py new file mode 100644 index 0000000..0c04ecb --- /dev/null +++ b/models/speaker-diarization/sortformer-streaming/conversion/build_all_variants.py @@ -0,0 +1,122 @@ +""" +Rebuild all 6 Sortformer pipeline variants from current conversion code. +Fixes issue #726 (stale HF models carry an input==output BNNS alias built on torch 2.9.x). +Local only — no upload. Each variant: write config.py, run convert_to_coreml.py in a fresh +subprocess, compile the pipeline to its app-expected name, verify no alias + ANE load. +""" +import os +import shutil +import subprocess +import sys + +HERE = os.path.dirname(os.path.abspath(__file__)) +PY = sys.executable +# Env-overridable so the same driver builds fp16 and palettized sets: +# PALETTIZE_NBITS=6 OUT_DIR=build_palettized_models python build_all_variants.py +PALETTIZE_NBITS = int(os.environ.get("PALETTIZE_NBITS", "0")) +OUT = os.path.join(HERE, os.environ.get("OUT_DIR", "build_fixed_models")) +os.makedirs(OUT, exist_ok=True) + +CONFIG_TEMPLATE = '''class Config: + chunk_len = {chunk_len} + chunk_right_context = {chunk_right_context} + chunk_left_context = {chunk_left_context} + fifo_len = {fifo_len} + spkcache_len = {spkcache_len} + spkcache_update_period = {spkcache_update_period} + + # do not touch these + subsampling_factor = 8 + sample_rate = 16000 + mel_window = 400 + mel_stride = 160 + frame_duration = 0.08 + + chunk_frames = (chunk_len + chunk_right_context + chunk_left_context) * subsampling_factor + coreml_audio_samples = (chunk_frames - 1) * mel_stride + mel_window + preproc_feature_frames = chunk_len * subsampling_factor + preproc_audio_hop = preproc_feature_frames * mel_stride +''' + +CONFIGS = { + "Default": dict(chunk_len=6, chunk_right_context=7, chunk_left_context=1, + fifo_len=40, spkcache_len=188, spkcache_update_period=31), + "NvidiaLow": dict(chunk_len=6, chunk_right_context=7, chunk_left_context=1, + fifo_len=188, spkcache_len=188, spkcache_update_period=144), + "NvidiaHigh": dict(chunk_len=340, chunk_right_context=40, chunk_left_context=1, + fifo_len=40, spkcache_len=188, spkcache_update_period=300), + # Higher-throughput streaming: Default context, larger 25-frame chunk (~2s output + # latency, ~4x RTFx of Default). Maps to Swift SortformerConfig.efficientV2_1. + "Efficient": dict(chunk_len=25, chunk_right_context=7, chunk_left_context=1, + fifo_len=40, spkcache_len=188, spkcache_update_period=31), +} + +# (config_key, model_version) -> final mlmodelc name expected by the app +VARIANTS = [ + ("Default", "v2.1", "Sortformer_v2.1"), + ("Default", "v2", "Sortformer_v2"), + ("NvidiaLow", "v2.1", "SortformerNvidiaLow_v2.1"), + ("NvidiaLow", "v2", "SortformerNvidiaLow_v2"), + ("NvidiaHigh", "v2.1", "SortformerNvidiaHigh_v2.1"), + ("NvidiaHigh", "v2", "SortformerNvidiaHigh_v2"), + ("Efficient", "v2.1", "SortformerEfficient_v2.1"), +] + +MODEL_NAME = { + "v2.1": "nvidia/diar_streaming_sortformer_4spk-v2.1", + "v2": "nvidia/diar_streaming_sortformer_4spk-v2", +} + + +def write_config(cfg_key): + with open(os.path.join(HERE, "config.py"), "w") as f: + f.write(CONFIG_TEMPLATE.format(**CONFIGS[cfg_key])) + + +def verify(mlc): + import coremltools as ct + mil = open(os.path.join(mlc, "model1", "model.mil")).read().splitlines()[3] + alias = "chunk_pre_encoder_embs_out" in mil + m = ct.models.CompiledMLModel(mlc, ct.ComputeUnit.ALL) # raises if it won't load on ANE + return (not alias) + + +def main(): + results = [] + for cfg_key, ver, final_name in VARIANTS: + print(f"\n{'='*70}\nBUILD {final_name} (config={cfg_key}, model={ver})\n{'='*70}", flush=True) + write_config(cfg_key) + build_dir = os.path.join(HERE, f"build_v_{final_name}") + if os.path.exists(build_dir): + shutil.rmtree(build_dir) + cmd = [PY, "convert_to_coreml.py", "--model_name", MODEL_NAME[ver], "--output_dir", build_dir] + if PALETTIZE_NBITS > 0: + cmd += ["--palettize_head_nbits", str(PALETTIZE_NBITS)] + r = subprocess.run(cmd, cwd=HERE, capture_output=True, text=True) + pkg = os.path.join(build_dir, "SortformerPipeline.mlpackage") + if not os.path.exists(pkg): + print(f" FAILED: no pipeline produced.\n stderr tail:\n{r.stderr[-1500:]}", flush=True) + results.append((final_name, "BUILD_FAILED")) + continue + import coremltools as ct + mlc = ct.utils.compile_model(pkg) + dst = os.path.join(OUT, final_name + ".mlmodelc") + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(mlc, dst) + try: + ok = verify(dst) + status = "OK (no alias, ANE load)" if ok else "ALIAS STILL PRESENT" + except Exception as e: + status = f"VERIFY_FAILED: {type(e).__name__}: {str(e)[:120]}" + sz = os.popen(f"du -sh '{dst}'").read().split()[0] + print(f" -> {dst} [{sz}] {status}", flush=True) + results.append((final_name, status)) + + print(f"\n{'='*70}\nSUMMARY\n{'='*70}", flush=True) + for name, status in results: + print(f" {name:32s} {status}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/models/speaker-diarization/sortformer-streaming/conversion/convert_to_coreml.py b/models/speaker-diarization/sortformer-streaming/conversion/convert_to_coreml.py index a5914db..928b2c9 100644 --- a/models/speaker-diarization/sortformer-streaming/conversion/convert_to_coreml.py +++ b/models/speaker-diarization/sortformer-streaming/conversion/convert_to_coreml.py @@ -4,6 +4,7 @@ import os import numpy as np +from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig, palettize_weights from nemo.collections.asr.models import SortformerEncLabelModel from coreml_wrappers import PreprocessorWrapper, PreEncoderWrapper, SortformerHeadWrapper from config import Config @@ -95,7 +96,8 @@ def export_pipeline( pre_encoder_precision: str = "fp32", head_precision: str = "fp16", skip_modules: bool = False, - verify: bool = False + verify: bool = False, + palettize_head_nbits: int = 0 ): """ Export the Sortformer model as a pipeline of separate CoreML models. @@ -300,9 +302,22 @@ def get_precision(s): # Both models now use compute_units=ALL. # The pre_encoder uses ANE-safe gather operations in fixed_concat_and_pad # to avoid zero-length slices that would crash on ANE. - + + # Optional weight palettization of the head (conformer+transformer = ~98% of size). + # 6-bit kmeans LUT compression cuts the model ~2.5x (matches Argmax's speakerkit) with + # no measurable speed change and preserves speaker-argmax decisions. Palettization (LUT) + # is GPU-safe, unlike int8 linear quantization which crashes MPSGraph (#726 RAM on old devices). + if palettize_head_nbits > 0: + print(f" Palettizing head weights to {palettize_head_nbits}-bit (kmeans LUT)...") + palette_cfg = OptimizationConfig( + global_config=OpPalettizerConfig( + nbits=palettize_head_nbits, mode="kmeans", weight_threshold=512 + ) + ) + head_mlmodel = palettize_weights(head_mlmodel, palette_cfg) + pipeline_model = ct.utils.make_pipeline( - pre_encoder_mlmodel, + pre_encoder_mlmodel, head_mlmodel, compute_units=ct.ComputeUnit.ALL ) @@ -372,6 +387,8 @@ def get_precision(s): help="Conformer encoder precision") parser.add_argument("--skip_modules", action="store_true", help="Skip modules in pipeline export") parser.add_argument("--verify", action="store_true", help="Skip pipeline in pipeline export") + parser.add_argument("--palettize_head_nbits", type=int, default=0, + help="If >0, palettize head weights to N-bit kmeans LUT (e.g. 6). ~2.5x smaller, GPU-safe.") args = parser.parse_args() @@ -385,4 +402,5 @@ def get_precision(s): head_precision=args.head_precision, skip_modules=args.skip_modules, verify=args.verify, + palettize_head_nbits=args.palettize_head_nbits, ) diff --git a/models/speaker-diarization/sortformer-streaming/conversion/offline_argmax_bench.py b/models/speaker-diarization/sortformer-streaming/conversion/offline_argmax_bench.py new file mode 100644 index 0000000..4945f21 --- /dev/null +++ b/models/speaker-diarization/sortformer-streaming/conversion/offline_argmax_bench.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""Offline head-to-head latency benchmark: FluidAudio fused Sortformer vs Argmax. + +Both process the same 30.72 s window (mel [1,128,3072] = 3072 x 10 ms hops). Argmax's +Sortformer is an OFFLINE batch model (3-model chain, no streaming state); FluidAudio's +offline export is a single fused graph (mel -> speaker_preds). This measures the offline +encoder/throughput gap only — it is NOT a streaming comparison. + +Models (override paths via the constants below): + - FluidAudio fused offline: exported via the NeMo offline path (streaming_mode=False); + inputs mel[1,128,3072] + mel_length[1] -> speaker_preds. + - Argmax: argmaxinc/speakerkit-pro sortformer/v2-1/384_94MB/{MelSpectrogram, + AudioConformerPreEncoder,SortformerFullEncoder}.mlmodelc (proprietary; download via + the Argmax v2 Playground app or the speakerkit-pro HF repo). + +Random inputs of the correct shape/dtype — valid for latency of these static-shape graphs. +Interleaved per-stage timing (median of N after warmup), ComputeUnit.ALL. + +Usage: python offline_argmax_bench.py +Result (M5 Pro, 2026-06-23): FluidAudio 1.3-1.4x faster offline; see Documentation/ +Diarization/Sortformer.md#benchmarks. +""" +import os +import time + +import numpy as np +import coremltools as ct + +# --- model paths (override with env vars) --- +ARGMAX_DIR = os.environ.get("ARGMAX_DIR", "/tmp/argmax_sf/sortformer/v2-1/384_94MB") +OURS_FP16 = os.environ.get("OURS_FP16", "/tmp/sf_offline_fp16.mlmodelc") +OURS_PALETTE6 = os.environ.get("OURS_PALETTE6", "/tmp/sf_offline_palette6.mlmodelc") + +CU = ct.ComputeUnit.ALL +WINDOW_S = 491520 / 16000.0 # 30.72 s +WARMUP, RUNS = 12, 120 + + +def load(path): + return ct.models.CompiledMLModel(path, compute_units=CU) + + +def bench(model, feed, warm=WARMUP, n=RUNS): + for _ in range(warm): + model.predict(feed) + ts = [] + for _ in range(n): + s = time.perf_counter() + model.predict(feed) + ts.append((time.perf_counter() - s) * 1e3) + ts = np.array(ts) + return float(np.median(ts)), float(np.percentile(ts, 95)) + + +def rtfx(ms): + return WINDOW_S / (ms / 1e3) + + +def main(): + print("loading models...", flush=True) + mel = load(f"{ARGMAX_DIR}/MelSpectrogram.mlmodelc") + pre = load(f"{ARGMAX_DIR}/AudioConformerPreEncoder.mlmodelc") + full = load(f"{ARGMAX_DIR}/SortformerFullEncoder.mlmodelc") + ours_fp16 = load(OURS_FP16) + ours_p6 = load(OURS_PALETTE6) + + f16 = np.float16 + audio = {"audio": np.random.randn(491520).astype(f16)} + melfeat = {"melspectrogram_features": np.random.randn(1, 1, 3073, 128).astype(f16)} + fullin = { + "downsampled_melspectrogram_features": np.random.randn(1, 512, 1, 384).astype(f16), + "conformer_encoder_padding_mask": np.zeros((1, 384), f16), + "conformer_encoder_qk_mask": np.zeros((1, 1, 384, 384), f16), + "transformer_encoder_mask": np.zeros((1, 384), f16), + "input_1": np.ones((1, 1, 1, 1), f16), + } + ours_in = {"mel": np.random.randn(1, 128, 3072).astype(np.float32), "mel_length": np.array([3072], np.int32)} + + mel_ms, _ = bench(mel, audio) + pre_ms, _ = bench(pre, melfeat) + full_ms, _ = bench(full, fullin) + ours_ms, ours_p95 = bench(ours_fp16, ours_in) + p6_ms, p6_p95 = bench(ours_p6, ours_in) + + argmax_enc = pre_ms + full_ms + argmax_e2e = mel_ms + pre_ms + full_ms + ours_e2e = ours_ms + mel_ms + p6_e2e = p6_ms + mel_ms + + print(f"\nWindow = {WINDOW_S:.2f}s, ComputeUnit.ALL, median of {RUNS} ({WARMUP} warmup)\n") + print(f"{'stage':38} {'ms':>8} {'RTFx':>9}") + print(f"{'Argmax MelSpectrogram':38} {mel_ms:8.2f} {rtfx(mel_ms):9.0f}") + print(f"{'Argmax AudioConformerPreEncoder':38} {pre_ms:8.2f} {rtfx(pre_ms):9.0f}") + print(f"{'Argmax SortformerFullEncoder':38} {full_ms:8.2f} {rtfx(full_ms):9.0f}") + print(f"{'OURS fused offline (fp16)':38} {ours_ms:8.2f} {rtfx(ours_ms):9.0f} p95 {ours_p95:.2f}") + print(f"{'OURS fused offline (palette6)':38} {p6_ms:8.2f} {rtfx(p6_ms):9.0f} p95 {p6_p95:.2f}") + print("\n=== ENCODER only (mel->preds, fair apples-to-apples) ===") + print(f" Argmax (PreEnc+FullEnc) {argmax_enc:8.2f} ms {rtfx(argmax_enc):6.0f}x (2 calls, ANE->GPU handoff)") + print(f" OURS fp16 fused {ours_ms:8.2f} ms {rtfx(ours_ms):6.0f}x (1 call) -> {argmax_enc/ours_ms:.2f}x faster") + print(f" OURS palette6 fused {p6_ms:8.2f} ms {rtfx(p6_ms):6.0f}x -> {argmax_enc/p6_ms:.2f}x faster") + print("\n=== END-TO-END incl mel (same MelSpectrogram tax both sides) ===") + print(f" Argmax full (3 calls) {argmax_e2e:8.2f} ms {rtfx(argmax_e2e):6.0f}x") + print(f" OURS fp16 + mel {ours_e2e:8.2f} ms {rtfx(ours_e2e):6.0f}x -> {argmax_e2e/ours_e2e:.2f}x faster") + print(f" OURS palette6 + mel {p6_e2e:8.2f} ms {rtfx(p6_e2e):6.0f}x -> {argmax_e2e/p6_e2e:.2f}x faster") + + +if __name__ == "__main__": + main()