Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Rebuild all 6 Sortformer pipeline variants from current conversion code.
Fixes issue #726 (stale HF models carry an input==output BNNS alias built on torch 2.9.x).
Local only — no upload. Each variant: write config.py, run convert_to_coreml.py in a fresh
subprocess, compile the pipeline to its app-expected name, verify no alias + ANE load.
"""
import os
import shutil
import subprocess
import sys

HERE = os.path.dirname(os.path.abspath(__file__))
PY = sys.executable
# Env-overridable so the same driver builds fp16 and palettized sets:
# PALETTIZE_NBITS=6 OUT_DIR=build_palettized_models python build_all_variants.py
PALETTIZE_NBITS = int(os.environ.get("PALETTIZE_NBITS", "0"))
OUT = os.path.join(HERE, os.environ.get("OUT_DIR", "build_fixed_models"))
os.makedirs(OUT, exist_ok=True)

CONFIG_TEMPLATE = '''class Config:
chunk_len = {chunk_len}
chunk_right_context = {chunk_right_context}
chunk_left_context = {chunk_left_context}
fifo_len = {fifo_len}
spkcache_len = {spkcache_len}
spkcache_update_period = {spkcache_update_period}

# do not touch these
subsampling_factor = 8
sample_rate = 16000
mel_window = 400
mel_stride = 160
frame_duration = 0.08

chunk_frames = (chunk_len + chunk_right_context + chunk_left_context) * subsampling_factor
coreml_audio_samples = (chunk_frames - 1) * mel_stride + mel_window
preproc_feature_frames = chunk_len * subsampling_factor
preproc_audio_hop = preproc_feature_frames * mel_stride
'''

CONFIGS = {
"Default": dict(chunk_len=6, chunk_right_context=7, chunk_left_context=1,
fifo_len=40, spkcache_len=188, spkcache_update_period=31),
"NvidiaLow": dict(chunk_len=6, chunk_right_context=7, chunk_left_context=1,
fifo_len=188, spkcache_len=188, spkcache_update_period=144),
"NvidiaHigh": dict(chunk_len=340, chunk_right_context=40, chunk_left_context=1,
fifo_len=40, spkcache_len=188, spkcache_update_period=300),
# Higher-throughput streaming: Default context, larger 25-frame chunk (~2s output
# latency, ~4x RTFx of Default). Maps to Swift SortformerConfig.efficientV2_1.
"Efficient": dict(chunk_len=25, chunk_right_context=7, chunk_left_context=1,
fifo_len=40, spkcache_len=188, spkcache_update_period=31),
}

# (config_key, model_version) -> final mlmodelc name expected by the app
VARIANTS = [
("Default", "v2.1", "Sortformer_v2.1"),
("Default", "v2", "Sortformer_v2"),
("NvidiaLow", "v2.1", "SortformerNvidiaLow_v2.1"),
("NvidiaLow", "v2", "SortformerNvidiaLow_v2"),
("NvidiaHigh", "v2.1", "SortformerNvidiaHigh_v2.1"),
("NvidiaHigh", "v2", "SortformerNvidiaHigh_v2"),
("Efficient", "v2.1", "SortformerEfficient_v2.1"),
]

MODEL_NAME = {
"v2.1": "nvidia/diar_streaming_sortformer_4spk-v2.1",
"v2": "nvidia/diar_streaming_sortformer_4spk-v2",
}


def write_config(cfg_key):
with open(os.path.join(HERE, "config.py"), "w") as f:
f.write(CONFIG_TEMPLATE.format(**CONFIGS[cfg_key]))


def verify(mlc):
import coremltools as ct
mil = open(os.path.join(mlc, "model1", "model.mil")).read().splitlines()[3]
alias = "chunk_pre_encoder_embs_out" in mil
m = ct.models.CompiledMLModel(mlc, ct.ComputeUnit.ALL) # raises if it won't load on ANE
return (not alias)


def main():
results = []
for cfg_key, ver, final_name in VARIANTS:
print(f"\n{'='*70}\nBUILD {final_name} (config={cfg_key}, model={ver})\n{'='*70}", flush=True)
write_config(cfg_key)
build_dir = os.path.join(HERE, f"build_v_{final_name}")
if os.path.exists(build_dir):
shutil.rmtree(build_dir)
cmd = [PY, "convert_to_coreml.py", "--model_name", MODEL_NAME[ver], "--output_dir", build_dir]
if PALETTIZE_NBITS > 0:
cmd += ["--palettize_head_nbits", str(PALETTIZE_NBITS)]
r = subprocess.run(cmd, cwd=HERE, capture_output=True, text=True)
pkg = os.path.join(build_dir, "SortformerPipeline.mlpackage")
if not os.path.exists(pkg):
print(f" FAILED: no pipeline produced.\n stderr tail:\n{r.stderr[-1500:]}", flush=True)
results.append((final_name, "BUILD_FAILED"))
continue
import coremltools as ct
mlc = ct.utils.compile_model(pkg)
dst = os.path.join(OUT, final_name + ".mlmodelc")
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.copytree(mlc, dst)
try:
ok = verify(dst)
status = "OK (no alias, ANE load)" if ok else "ALIAS STILL PRESENT"
except Exception as e:
status = f"VERIFY_FAILED: {type(e).__name__}: {str(e)[:120]}"
sz = os.popen(f"du -sh '{dst}'").read().split()[0]
print(f" -> {dst} [{sz}] {status}", flush=True)
results.append((final_name, status))

print(f"\n{'='*70}\nSUMMARY\n{'='*70}", flush=True)
for name, status in results:
print(f" {name:32s} {status}", flush=True)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import numpy as np

from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig, palettize_weights
from nemo.collections.asr.models import SortformerEncLabelModel
from coreml_wrappers import PreprocessorWrapper, PreEncoderWrapper, SortformerHeadWrapper
from config import Config
Expand Down Expand Up @@ -95,7 +96,8 @@ def export_pipeline(
pre_encoder_precision: str = "fp32",
head_precision: str = "fp16",
skip_modules: bool = False,
verify: bool = False
verify: bool = False,
palettize_head_nbits: int = 0
):
"""
Export the Sortformer model as a pipeline of separate CoreML models.
Expand Down Expand Up @@ -300,9 +302,22 @@ def get_precision(s):
# Both models now use compute_units=ALL.
# The pre_encoder uses ANE-safe gather operations in fixed_concat_and_pad
# to avoid zero-length slices that would crash on ANE.


# Optional weight palettization of the head (conformer+transformer = ~98% of size).
# 6-bit kmeans LUT compression cuts the model ~2.5x (matches Argmax's speakerkit) with
# no measurable speed change and preserves speaker-argmax decisions. Palettization (LUT)
# is GPU-safe, unlike int8 linear quantization which crashes MPSGraph (#726 RAM on old devices).
if palettize_head_nbits > 0:
print(f" Palettizing head weights to {palettize_head_nbits}-bit (kmeans LUT)...")
palette_cfg = OptimizationConfig(
global_config=OpPalettizerConfig(
nbits=palettize_head_nbits, mode="kmeans", weight_threshold=512
)
)
head_mlmodel = palettize_weights(head_mlmodel, palette_cfg)

pipeline_model = ct.utils.make_pipeline(
pre_encoder_mlmodel,
pre_encoder_mlmodel,
head_mlmodel,
compute_units=ct.ComputeUnit.ALL
)
Expand Down Expand Up @@ -372,6 +387,8 @@ def get_precision(s):
help="Conformer encoder precision")
parser.add_argument("--skip_modules", action="store_true", help="Skip modules in pipeline export")
parser.add_argument("--verify", action="store_true", help="Skip pipeline in pipeline export")
parser.add_argument("--palettize_head_nbits", type=int, default=0,
help="If >0, palettize head weights to N-bit kmeans LUT (e.g. 6). ~2.5x smaller, GPU-safe.")

args = parser.parse_args()

Expand All @@ -385,4 +402,5 @@ def get_precision(s):
head_precision=args.head_precision,
skip_modules=args.skip_modules,
verify=args.verify,
palettize_head_nbits=args.palettize_head_nbits,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""Offline head-to-head latency benchmark: FluidAudio fused Sortformer vs Argmax.

Both process the same 30.72 s window (mel [1,128,3072] = 3072 x 10 ms hops). Argmax's
Sortformer is an OFFLINE batch model (3-model chain, no streaming state); FluidAudio's
offline export is a single fused graph (mel -> speaker_preds). This measures the offline
encoder/throughput gap only — it is NOT a streaming comparison.

Models (override paths via the constants below):
- FluidAudio fused offline: exported via the NeMo offline path (streaming_mode=False);
inputs mel[1,128,3072] + mel_length[1] -> speaker_preds.
- Argmax: argmaxinc/speakerkit-pro sortformer/v2-1/384_94MB/{MelSpectrogram,
AudioConformerPreEncoder,SortformerFullEncoder}.mlmodelc (proprietary; download via
the Argmax v2 Playground app or the speakerkit-pro HF repo).

Random inputs of the correct shape/dtype — valid for latency of these static-shape graphs.
Interleaved per-stage timing (median of N after warmup), ComputeUnit.ALL.

Usage: python offline_argmax_bench.py
Result (M5 Pro, 2026-06-23): FluidAudio 1.3-1.4x faster offline; see Documentation/
Diarization/Sortformer.md#benchmarks.
"""
import os
import time

import numpy as np
import coremltools as ct

# --- model paths (override with env vars) ---
ARGMAX_DIR = os.environ.get("ARGMAX_DIR", "/tmp/argmax_sf/sortformer/v2-1/384_94MB")
OURS_FP16 = os.environ.get("OURS_FP16", "/tmp/sf_offline_fp16.mlmodelc")
OURS_PALETTE6 = os.environ.get("OURS_PALETTE6", "/tmp/sf_offline_palette6.mlmodelc")

CU = ct.ComputeUnit.ALL
WINDOW_S = 491520 / 16000.0 # 30.72 s
WARMUP, RUNS = 12, 120


def load(path):
return ct.models.CompiledMLModel(path, compute_units=CU)


def bench(model, feed, warm=WARMUP, n=RUNS):
for _ in range(warm):
model.predict(feed)
ts = []
for _ in range(n):
s = time.perf_counter()
model.predict(feed)
ts.append((time.perf_counter() - s) * 1e3)
ts = np.array(ts)
return float(np.median(ts)), float(np.percentile(ts, 95))


def rtfx(ms):
return WINDOW_S / (ms / 1e3)


def main():
print("loading models...", flush=True)
mel = load(f"{ARGMAX_DIR}/MelSpectrogram.mlmodelc")
pre = load(f"{ARGMAX_DIR}/AudioConformerPreEncoder.mlmodelc")
full = load(f"{ARGMAX_DIR}/SortformerFullEncoder.mlmodelc")
ours_fp16 = load(OURS_FP16)
ours_p6 = load(OURS_PALETTE6)

f16 = np.float16
audio = {"audio": np.random.randn(491520).astype(f16)}
melfeat = {"melspectrogram_features": np.random.randn(1, 1, 3073, 128).astype(f16)}
fullin = {
"downsampled_melspectrogram_features": np.random.randn(1, 512, 1, 384).astype(f16),
"conformer_encoder_padding_mask": np.zeros((1, 384), f16),
"conformer_encoder_qk_mask": np.zeros((1, 1, 384, 384), f16),
"transformer_encoder_mask": np.zeros((1, 384), f16),
"input_1": np.ones((1, 1, 1, 1), f16),
}
ours_in = {"mel": np.random.randn(1, 128, 3072).astype(np.float32), "mel_length": np.array([3072], np.int32)}

mel_ms, _ = bench(mel, audio)
pre_ms, _ = bench(pre, melfeat)
full_ms, _ = bench(full, fullin)
ours_ms, ours_p95 = bench(ours_fp16, ours_in)
p6_ms, p6_p95 = bench(ours_p6, ours_in)

argmax_enc = pre_ms + full_ms
argmax_e2e = mel_ms + pre_ms + full_ms
ours_e2e = ours_ms + mel_ms
p6_e2e = p6_ms + mel_ms

print(f"\nWindow = {WINDOW_S:.2f}s, ComputeUnit.ALL, median of {RUNS} ({WARMUP} warmup)\n")
print(f"{'stage':38} {'ms':>8} {'RTFx':>9}")
print(f"{'Argmax MelSpectrogram':38} {mel_ms:8.2f} {rtfx(mel_ms):9.0f}")
print(f"{'Argmax AudioConformerPreEncoder':38} {pre_ms:8.2f} {rtfx(pre_ms):9.0f}")
print(f"{'Argmax SortformerFullEncoder':38} {full_ms:8.2f} {rtfx(full_ms):9.0f}")
print(f"{'OURS fused offline (fp16)':38} {ours_ms:8.2f} {rtfx(ours_ms):9.0f} p95 {ours_p95:.2f}")
print(f"{'OURS fused offline (palette6)':38} {p6_ms:8.2f} {rtfx(p6_ms):9.0f} p95 {p6_p95:.2f}")
print("\n=== ENCODER only (mel->preds, fair apples-to-apples) ===")
print(f" Argmax (PreEnc+FullEnc) {argmax_enc:8.2f} ms {rtfx(argmax_enc):6.0f}x (2 calls, ANE->GPU handoff)")
print(f" OURS fp16 fused {ours_ms:8.2f} ms {rtfx(ours_ms):6.0f}x (1 call) -> {argmax_enc/ours_ms:.2f}x faster")
print(f" OURS palette6 fused {p6_ms:8.2f} ms {rtfx(p6_ms):6.0f}x -> {argmax_enc/p6_ms:.2f}x faster")
print("\n=== END-TO-END incl mel (same MelSpectrogram tax both sides) ===")
print(f" Argmax full (3 calls) {argmax_e2e:8.2f} ms {rtfx(argmax_e2e):6.0f}x")
print(f" OURS fp16 + mel {ours_e2e:8.2f} ms {rtfx(ours_e2e):6.0f}x -> {argmax_e2e/ours_e2e:.2f}x faster")
print(f" OURS palette6 + mel {p6_e2e:8.2f} ms {rtfx(p6_e2e):6.0f}x -> {argmax_e2e/p6_e2e:.2f}x faster")


if __name__ == "__main__":
main()