
WIP Support per expert amax in TEGroupedMLP #1550

Open
jenchen13 wants to merge 4 commits into main from jennifchen/te_per_expert

Conversation

Contributor

@jenchen13 jenchen13 commented May 27, 2026

What does this PR do?

Type of change: ?

Note: if a checkpoint is restored with a different TP/EP configuration than the one it was saved with, the per-expert weight quantizers may not restore properly. TODO: fix this.

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Opt-in per-expert weight quantization for Transformer Engine grouped linear layers, with per-expert calibration on model restore when enabled.
  • Tests

    • Added Mixture-of-Experts (MoE) quantization validation tests comparing grouped vs sequential expert configurations and quantization error metrics.

Review Change Stack

jenchen13 added 3 commits May 27, 2026 15:41
…nfra fixes

modelopt/torch/quantization/plugins/transformer_engine.py:
  MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 opts into per-gemm
  weight_quantizer_0..N-1 inside _QuantTEGroupedLinear (deepcopied from
  the shared weight_quantizer). Lets TEGroupedMLP recover per-expert
  amax granularity, matching SequentialMLP's default behavior.

modelopt/torch/distill/plugins/megatron.py:
  LogitsKLLoss.forward prints student/teacher logit stats (mean/std/
  min/max/shape) on rank 0 each call. Diagnostic for the QAD loss-spike
  investigation — confirms which spec produces which logits without
  changing the KL math.

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py:
  New test_te_grouped_vs_sequential_default_amax + ..._default_loss
  cover the structural amax asymmetry between TEGroupedMLP and
  SequentialMLP (TEGrouped per-linear amax = max-over-Sequential-experts
  amax) and a finiteness sanity check on the resulting quant error.

tools/launcher/common/service_utils.sh:
  - Fall back to SLURM_PROCID / SLURM_LOCALID when PMIX_*/OMPI_* are
    unset, so `[[ "$mpi_local_rank" -eq 0 ]]` doesn't silently pass on
    every rank under plain srun.
  - util_install_extra_dep: per-node marker so concurrent ranks wait
    for rank 0 to finish installing (concurrent pip on a shared FS
    leaves a broken state); also installs nvidia-resiliency-ext.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
- transformer_engine.py: dedup `import copy`/`import os` left over from the
  rebase, sort the four imports alphabetically.
- transformer_engine.py: comment near the per-expert weight_quantizer setup
  explaining that base modelopt_post_restore won't re-calibrate the
  weight_quantizer_{i} modules, so save/restore is only safe when TP/EP is
  unchanged. Per-channel _amax shape depends on the TP-sliced output dim.
- service_utils.sh: drop the duplicated mpi_rank / mpi_local_rank
  re-assignments — main already carries the SLURM fallback, the extra two
  lines were leftover rebase noise.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
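
The rank fallback described for service_utils.sh in the commit messages above can be sketched in Python. This is only an illustration of the lookup order, not the shell code from the PR: the commits name only `PMIX_*/OMPI_*`, so the specific `PMIX_RANK` and `OMPI_COMM_WORLD_*` names below are assumptions, and `resolve_rank` is a hypothetical helper. The sketch raises when nothing is set, whereas a bash test such as `[[ "$mpi_local_rank" -eq 0 ]]` treats an empty value as 0 and passes on every rank.

```python
import os

# Assumed variable names; the commit message only says "PMIX_*/OMPI_*".
GLOBAL_RANK_VARS = ("PMIX_RANK", "OMPI_COMM_WORLD_RANK", "SLURM_PROCID")
LOCAL_RANK_VARS = ("OMPI_COMM_WORLD_LOCAL_RANK", "SLURM_LOCALID")


def resolve_rank(candidates, env=None):
    """Return the first rank found among candidates, in priority order."""
    env = os.environ if env is None else env
    for name in candidates:
        value = env.get(name, "")
        if value != "":
            return int(value)
    # Fail loudly instead of letting an empty value be treated as rank 0.
    raise RuntimeError(f"none of {candidates} is set; cannot determine rank")


if __name__ == "__main__":
    # Under plain srun only the SLURM_* variables are expected to be present.
    print(resolve_rank(LOCAL_RANK_VARS, env={"SLURM_LOCALID": "3"}))
```

Passing an explicit `env` mapping keeps the helper testable without touching the real process environment.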
@jenchen13 jenchen13 requested a review from a team as a code owner May 27, 2026 23:00
@jenchen13 jenchen13 requested a review from sychen52 May 27, 2026 23:00
@coderabbitai
Contributor

coderabbitai Bot commented May 27, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 1ce4fbe3-0096-4792-aea4-66dc4dee5df3

📥 Commits

Reviewing files that changed from the base of the PR and between 4a58f02 and b1e32d9.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/transformer_engine.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/transformer_engine.py

📝 Walkthrough


This PR adds an environment-gated per-GEMM weight-quantizer mode to Transformer Engine's TEGroupedLinear, implements deep-copied per-GEMM quantizers, updates restore/calibration and forward paths to use per-GEMM quantizers, and adds tests comparing TEGrouped vs Sequential MoE quantization behavior.
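
As a rough illustration of the mode summarized above, the sketch below gates per-GEMM quantizers on the environment variable, deep-copies one shared quantizer into `weight_quantizer_{i}` attributes, and selects one by index. `ToyQuantizer` and `ToyGroupedLinear` are hypothetical stand-ins, not the ModelOpt or Transformer Engine classes, and treating only the literal `"1"` as enabled is an assumption taken from the commit message; the real `_per_expert_weight_quantizer_enabled()` may parse the variable differently.

```python
import copy
import os

ENV_FLAG = "MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"


def per_expert_enabled():
    # Assumption: the literal string "1" turns the mode on; anything else is off.
    return os.environ.get(ENV_FLAG, "0") == "1"


class ToyQuantizer:
    """Hypothetical stand-in for a weight quantizer tracking a running amax."""

    def __init__(self):
        self.amax = None

    def calibrate(self, weights):
        observed = max(abs(w) for w in weights)
        self.amax = observed if self.amax is None else max(self.amax, observed)


class ToyGroupedLinear:
    """Hypothetical stand-in for a grouped linear layer holding num_gemms experts."""

    def __init__(self, num_gemms, per_expert=None):
        self.num_gemms = num_gemms
        self.weight_quantizer = ToyQuantizer()
        # Fall back to the environment flag when the caller does not decide.
        self.per_expert = per_expert_enabled() if per_expert is None else per_expert
        if self.per_expert:
            for i in range(num_gemms):
                # Deep copy so each GEMM keeps independent calibration state.
                setattr(self, f"weight_quantizer_{i}", copy.deepcopy(self.weight_quantizer))

    def get_weight_quantizer(self, gemm_idx):
        if self.per_expert:
            return getattr(self, f"weight_quantizer_{gemm_idx}")
        return self.weight_quantizer

    def calibrate(self, expert_weights):
        for gemm_idx, weights in enumerate(expert_weights):
            self.get_weight_quantizer(gemm_idx).calibrate(weights)
```

With the flag off, every expert feeds the same quantizer, so its amax ends up as the maximum over all experts; with the flag on, each expert keeps its own value.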

Changes

Per-expert weight quantization for TEGroupedLinear

| Layer | File(s) | Summary |
| --- | --- | --- |
| Environment control and imports | modelopt/torch/quantization/plugins/transformer_engine.py | Module imports updated with copy and os. New _per_expert_weight_quantizer_enabled() reads MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER environment variable. |
| Per-GEMM weight quantizer setup and selection | modelopt/torch/quantization/plugins/transformer_engine.py | _QuantTEGroupedLinear._setup() conditionally initializes per-GEMM weight quantizers by deep-copying base weight_quantizer into weight_quantizer_{i} modules. Added _get_weight_quantizer(gemm_idx) and updated modelopt_post_restore/iter_weights_for_calibration() to handle per-GEMM quantizers and re-calibration when _amax is present. |
| Quantized forward execution with per-GEMM quantizers | modelopt/torch/quantization/plugins/transformer_engine.py | Quantized TEGrouped linear forward path now applies _get_weight_quantizer(gemm_idx) per GEMM to each weight argument instead of using a single shared quantizer. |
| MoE quantization validation tests | tests/gpu_megatron/torch/quantization/plugins/test_megatron.py | Added import math and two test helpers/tests that (1) verify TEGrouped per-linear amax equals the max across Sequential experts while Sequential shows per-expert amax divergence, and (2) compare post-quantization mean absolute output error vs BF16 references for TEGrouped and Sequential models. |
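
To see why the amax granularity checked by these tests matters, here is a small worked example under simplifying assumptions: it uses a symmetric integer grid with 127 levels rather than the FP8 or NVFP4 formats the tests exercise, and the weights are made-up numbers. It compares the mean absolute error on weights when all experts share one amax (the maximum over experts) against per-expert amax values; the PR's tests measure output error against BF16 references instead.

```python
def amax(weights):
    """Largest absolute value in a list of weights."""
    return max(abs(w) for w in weights)


def fake_quantize(weights, amax_value, levels=127):
    # Symmetric quantization: snap each weight to a grid with step amax / levels.
    scale = amax_value / levels
    return [round(w / scale) * scale for w in weights]


def mean_abs_error(weights, amax_value):
    quantized = fake_quantize(weights, amax_value)
    return sum(abs(w - q) for w, q in zip(weights, quantized)) / len(weights)


# Illustrative weights: expert 0 is much smaller in magnitude than expert 1.
experts = [[0.01, -0.02, 0.03], [1.0, -4.0, 2.5]]

per_expert_amax = [amax(w) for w in experts]
shared_amax = max(per_expert_amax)  # what a single shared quantizer would see

shared_errors = [mean_abs_error(w, shared_amax) for w in experts]
per_expert_errors = [mean_abs_error(w, a) for w, a in zip(experts, per_expert_amax)]

if __name__ == "__main__":
    for i, (s, p) in enumerate(zip(shared_errors, per_expert_errors)):
        print(f"expert {i}: shared-amax error {s:.6f}, per-expert error {p:.6f}")
```

The expert that sets the shared amax is unaffected, while the small-magnitude expert loses resolution under the shared scale.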

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'WIP Support per expert amax in TEGroupedMLP' directly describes the main feature being added: per-expert amax support for TEGroupedMLP weight quantization, which aligns with the changeset's core functionality. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns found. No unsafe deserialization, hardcoded trust_remote_code, eval/exec on untrusted input, or nosec comments. Only stdlib imports with safe usage. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@github-actions
Contributor

github-actions Bot commented May 27, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1550/

Built to branch gh-pages at 2026-05-27 23:10 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/transformer_engine.py (1)

147-166: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Per-expert restore is still unsafe when TP/EP changes.

This branch creates weight_quantizer_{i} modules, but modelopt_post_restore() still runs the base restore path that only knows about self.weight_quantizer. If a checkpoint is restored under different TP/EP, the per-expert _amax tensors can keep the old shape and no code here reinitializes them. Please either block this mode on topology changes or add a per-expert reset/recalibration path during restore.
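
One possible reading of "block this mode on topology changes" is a fail-fast shape check at restore time. The sketch below is not ModelOpt code: `check_per_expert_amax_shapes` is a hypothetical helper, shapes are plain tuples, and the per-channel amax layout `(out_features, 1)` is an assumption based on the commit note that the per-channel `_amax` shape depends on the TP-sliced output dimension.

```python
def expected_amax_shape(weight_shape, per_channel=True):
    # Assumed layout: one amax per output row for per-channel, a scalar otherwise.
    return (weight_shape[0], 1) if per_channel else ()


def check_per_expert_amax_shapes(restored_amax_shapes, weight_shapes, per_channel=True):
    """Raise if any restored per-expert amax no longer matches its weight."""
    if len(restored_amax_shapes) != len(weight_shapes):
        raise RuntimeError(
            f"restored {len(restored_amax_shapes)} per-expert quantizers "
            f"but the layer now has {len(weight_shapes)} GEMMs"
        )
    for i, (amax_shape, weight_shape) in enumerate(zip(restored_amax_shapes, weight_shapes)):
        expected = expected_amax_shape(weight_shape, per_channel)
        if tuple(amax_shape) != expected:
            raise RuntimeError(
                f"weight_quantizer_{i}: restored amax shape {tuple(amax_shape)} does not "
                f"match expected {expected}; was the checkpoint saved with a different TP/EP?"
            )


if __name__ == "__main__":
    # Same topology: output dim is unchanged, so the check passes silently.
    check_per_expert_amax_shapes([(64, 1), (64, 1)], [(64, 32), (64, 32)])
```

The alternative the reviewer mentions, recalibrating each per-expert quantizer after restore, would replace the `raise` with a reset of that quantizer's amax followed by calibration on the current weight.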

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/plugins/transformer_engine.py` around lines 147 -
166, The per-expert weight-quantizer path (_per_expert_weight_quantizer and the
created modules weight_quantizer_{i}) is unsafe on TP/EP topology change because
modelopt_post_restore currently only runs the base path and does not
reinitialize per-expert _amax buffers; update modelopt_post_restore to detect
topology/shape mismatch and either (a) fail-fast by raising a clear error when
_per_expert_weight_quantizer is enabled and the restored per-expert _amax shapes
don't match the current weight shapes, or (b) perform a per-expert
recalibration: iterate over each created module name weight_quantizer_{i} (use
self.num_gemms to enumerate), reset/recreate their _amax buffers to the correct
shape for the current self.weight{i} and call the quantizer's max_calibrate (or
its equivalent reinit routine) so each per-expert quantizer is re-calibrated
after restore; implement one of these branches in modelopt_post_restore to
ensure shapes/amax are consistent after load.
🧹 Nitpick comments (1)
tests/gpu_megatron/torch/quantization/plugins/test_megatron.py (1)

738-750: ⚡ Quick win

Keep these GPU assertions on-device.

tensor.item() and math.isfinite(...) force host syncs on every worker here. Prefer tensor comparisons/reductions plus torch.isfinite, and only materialize Python scalars for failure messages or rank-0 logging.

As per coding guidelines, "Performance note: avoid GPU syncs in hot paths/tests (don’t use tensor.item()/float(tensor)/min(tensor))—use tensor ops instead."

Also applies to: 800-813

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py` around lines
738 - 750, The test is forcing GPU->CPU syncs by calling .item() inside the
assert message and in the divergence check; change the assert to keep
comparisons on-device and only materialize scalars when needed: replace the
assert torch.allclose(te_amax.view(()), seq_max, ...) with an if-not pattern (if
not torch.allclose(...): raise
AssertionError(f"...{te_amax.item()}...{seq_max.item()}...")) so the .item()
calls only run on failure, and compute the per-expert divergence using tensor
ops (delta = seq_amaxes.max() - seq_amaxes.min(); if (delta > 1e-5):
saw_per_expert_divergence = True) while ensuring you only call .item() when you
must convert to a Python bool for logging or failure paths; update references to
seq_amaxes, seq_mlp.local_experts, linear_name, te_amax, and
saw_per_expert_divergence accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py`:
- Around line 698-761: The default-path helper
_test_te_grouped_vs_sequential_default_amax_helper must explicitly disable the
feature flag: set os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="0"
(saving the original and restoring it after the test) before calling
initialize_for_megatron so the shared-quantizer behavior is forced; then add a
new test variant that sets
os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="1" and runs the same
checks to cover the env-on path. Apply the same explicit env-off/restore and
corresponding env-on test change to the other similar helper/test pair
exercising TEGrouped vs Sequential (the companion helper referenced in the
review) so both default and env-enabled code paths have dedicated coverage.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5fbda55f-b0c7-4054-bc00-63b4bc3a150e

📥 Commits

Reviewing files that changed from the base of the PR and between b49f9b9 and 4a58f02.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/transformer_engine.py
  • tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Comment on lines +698 to +761
def _test_te_grouped_vs_sequential_default_amax_helper(tp_size, ep_size, quant_cfg, rank, size):
    """TEGrouped per-linear amax should equal max-over-Sequential-experts under default sync=False."""
    initialize_for_megatron(
        tensor_model_parallel_size=tp_size,
        expert_model_parallel_size=ep_size,
        seed=SEED,
    )

    te_grouped = _gpt_model_provider(
        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=True,
        transformer_impl="transformer_engine", num_moe_experts=4,
    )
    forward = get_forward(te_grouped, batch_size=8)

    sequential = _gpt_model_provider(
        tp_size=tp_size, ep_size=ep_size, hidden_size=32, moe_grouped_gemm=False,
        num_moe_experts=4, transformer_impl="modelopt",
    )
    copy_weights_from_grouped_to_non_grouped(te_grouped, sequential)

    for module in te_grouped.modules():
        if isinstance(module, TopKRouter):
            module.topk = module.num_experts
    for module in sequential.modules():
        if isinstance(module, TopKRouter):
            module.topk = module.num_experts

    mtq.quantize(te_grouped, quant_cfg, forward)
    mtq.quantize(sequential, quant_cfg, forward)

    te_modules = [m for m in te_grouped.modules() if isinstance(m, TEGroupedMLP)]
    seq_modules = [m for m in sequential.modules() if isinstance(m, SequentialMLP)]
    assert len(te_modules) == len(seq_modules)

    saw_per_expert_divergence = False
    for te_mlp, seq_mlp in zip(te_modules, seq_modules):
        for linear_name in ("linear_fc1", "linear_fc2"):
            te_amax = getattr(te_mlp, linear_name).weight_quantizer.amax
            assert te_amax is not None and te_amax.numel() == 1

            seq_amaxes = torch.stack([
                getattr(expert, linear_name).weight_quantizer.amax.view(())
                for expert in seq_mlp.local_experts
            ])
            seq_max = seq_amaxes.max()

            assert torch.allclose(te_amax.view(()), seq_max, atol=1e-5, rtol=1e-5), (
                f"TEGrouped per-linear amax ({te_amax.item()}) != "
                f"max-over-Sequential-experts ({seq_max.item()}) for {linear_name}"
            )

            if (seq_amaxes.max() - seq_amaxes.min()).item() > 1e-5:
                saw_per_expert_divergence = True

    assert saw_per_expert_divergence, (
        "Expected per-expert weight amax to diverge across SequentialMLP experts."
    )


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
def test_te_grouped_vs_sequential_default_amax(dist_workers_size_4, quant_cfg):
    dist_workers_size_4.run(
        partial(_test_te_grouped_vs_sequential_default_amax_helper, 1, 2, quant_cfg)
    )
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make these “default” tests control the feature flag explicitly.

These helpers never clear or set MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER, so a worker process with that flag exported will exercise the new per-expert path instead of the default shared-quantizer path. That also means the env-enabled branch added in this PR still has no dedicated coverage. Please force the flag off in these default-behavior tests and add a separate env-on case for the new path.

Also applies to: 764-820
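
A minimal sketch of the save/set/restore pattern this comment asks for, using only the standard library. `env_flag` is a hypothetical helper, not something in the test suite; because these helpers run inside worker processes, the variable would need to be set inside the helper itself (before `initialize_for_megatron`) rather than only in the parent test.

```python
import os
from contextlib import contextmanager

FLAG = "MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"


@contextmanager
def env_flag(name, value):
    """Temporarily set an environment variable, restoring the prior state on exit."""
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is None:
            # The variable was unset before; remove it again.
            os.environ.pop(name, None)
        else:
            os.environ[name] = original


if __name__ == "__main__":
    with env_flag(FLAG, "0"):
        # Default-path checks would run here with the flag forced off.
        print(os.environ[FLAG])
    with env_flag(FLAG, "1"):
        # Env-on checks would run here to cover the per-expert path.
        print(os.environ[FLAG])
```

Restoring in a `finally` block keeps a failing assertion in one test from leaking the flag into later tests in the same process.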

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py` around lines
698 - 761, The default-path helper
_test_te_grouped_vs_sequential_default_amax_helper must explicitly disable the
feature flag: set os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="0"
(saving the original and restoring it after the test) before calling
initialize_for_megatron so the shared-quantizer behavior is forced; then add a
new test variant that sets
os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"]="1" and runs the same
checks to cover the env-on path. Apply the same explicit env-off/restore and
corresponding env-on test change to the other similar helper/test pair
exercising TEGrouped vs Sequential (the companion helper referenced in the
review) so both default and env-enabled code paths have dedicated coverage.

@codecov

codecov Bot commented May 27, 2026

Codecov Report

❌ Patch coverage is 17.24138% with 24 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.45%. Comparing base (b49f9b9) to head (b1e32d9).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...t/torch/quantization/plugins/transformer_engine.py | 17.24% | 24 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1550      +/-   ##
==========================================
- Coverage   76.70%   69.45%   -7.26%     
==========================================
  Files         477      477              
  Lines       51977    52002      +25     
==========================================
- Hits        39868    36116    -3752     
- Misses      12109    15886    +3777     
| Flag | Coverage Δ |
| --- | --- |
| examples | 33.63% <17.24%> (-5.47%) ⬇️ |
| gpu | 50.97% <17.24%> (-9.16%) ⬇️ |
| regression | 15.23% <17.24%> (+0.06%) ⬆️ |
| unit | 52.74% <6.89%> (-0.02%) ⬇️ |

