Skip to content

FSDP2 calibration with hf_ptq.py#1563

Draft
sugunav14 wants to merge 5 commits into
mainfrom
svelury/fsdp2-refactor
Draft

FSDP2 calibration with hf_ptq.py#1563
sugunav14 wants to merge 5 commits into
mainfrom
svelury/fsdp2-refactor

Conversation

@sugunav14
Copy link
Copy Markdown
Contributor

@sugunav14 sugunav14 commented May 28, 2026

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • FSDP2-based post-training quantization flow for single- and multi-node runs
    • NVFP4 "max" quantization format and optional per-layer calibration
  • Documentation

    • Rewritten PTQ guide with torchrun-based usage and FSDP2 instructions
  • Removed

    • Legacy Accelerate-based multi-node PTQ workflow and related config/scripts

Review Change Stack

sugunav14 added 4 commits May 21, 2026 21:16
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
@sugunav14 sugunav14 requested review from a team as code owners May 28, 2026 23:22
@sugunav14 sugunav14 requested a review from realAsma May 28, 2026 23:22
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 28, 2026

📝 Walkthrough

Walkthrough

Consolidates multi-node LLM PTQ into hf_ptq.py by adding FSDP2 utilities (wrapping, meta init, sharding, state gathering), wiring dataloader sharding and FSDP2-aware calibration/export flows, adding NVFP4 max configs, and replacing the Accelerate-based multinode workflow with torchrun instructions.

Changes

FSDP2 PTQ Consolidation

Layer / File(s) Summary
FSDP2 distributed helpers foundation
modelopt/torch/utils/distributed.py
New exports and implementations: fsdp2_wrap, init_params_on_meta, fsdp2_shard, shard_dataloader, fsdp_aware_forward_loop, and Fsdp2StateDictAdapter.get_state_dict to support rank-aware materialization, sharding, dataloader sharding, and full-state gathering.
Quantization FSDP2 mesh robustness
modelopt/torch/quantization/utils/core_utils.py
Improved _get_fsdp2_mesh to prefer post_forward_mesh_info, fall back to mesh_info, and defensively handle None cases returning info.mesh.
Example utilities FSDP2 integration
examples/llm_ptq/example_utils.py
Adds setup_distributed_args, cleanup_distributed, validate_fsdp2_supported, and load_and_prepare_fsdp2_model implementing rank-0 CPU load + meta-device init + fsdp2_shard for causal LM checkpoints.
Main PTQ script FSDP2 orchestration
examples/llm_ptq/hf_ptq.py
Imports and wires FSDP2 helpers: adds _nvfp4_max_cfg and nvfp4_max(_layerwise) choices; shards calibration dataloader per rank; uses load_and_prepare_fsdp2_model when --use_fsdp2; switches to fsdp_aware_forward_loop; adds _export_fsdp2_hf_checkpoint for gathered-state export; gates disk writes to args.is_main; adds --use_fsdp2 and --cpu_offload flags with validation; calls setup_distributed_args/cleanup_distributed and wraps main in patch_fsdp_mp_dtypes().
Documentation and legacy removal
examples/llm_ptq/README.md
Replaces Accelerate-based multinode instructions with torchrun command templates for single-node and multi-node FSDP2 usage and documents --qformat nvfp4_max_layerwise layerwise calibration. Removed multinode_ptq.py/fsdp2.yaml legacy references.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • realAsma
  • jingyu-ml
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 65.38% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'FSDP2 calibration with hf_ptq.py' accurately describes the main change: integrating FSDP2 distributed training support into the hf_ptq.py entry point for post-training quantization.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found: trust_remote_code parameterized with False defaults, no unsafe torch.load/numpy.load, no eval/exec, no nosec, no new unsafe dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch svelury/fsdp2-refactor
⚔️ Resolve merge conflicts
  • Resolve merge conflict in branch svelury/fsdp2-refactor

Comment @coderabbitai help to get the list of available commands and usage tips.

@sugunav14 sugunav14 marked this pull request as draft May 28, 2026 23:23
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 28, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 68-78: The else branch for when getattr(args, "use_fsdp2", False)
is false is overwriting args.device with None which breaks non-FSDP device
handling (e.g., --device cpu); change the else branch in that block so it sets
args.rank = 0, args.world_size = 1 and args.is_main = True but does NOT modify
args.device (leave existing args.device intact or only set it if undefined), so
downstream calls like get_model() receive the requested device; keep reference
to getattr(args, "use_fsdp2", False), args.device, args.rank, args.world_size,
args.is_main and get_model() to locate the code to change.
- Around line 183-201: Modify create_fsdp2_calibration_loop to accept an is_main
(default False) parameter and thread it into the inner calibrate closure; inside
calibrate only wrap the dataloader with tqdm(desc="Calibrating") when is_main is
True and otherwise iterate the dataloader without tqdm to avoid per-rank
progress bars, and use print_rank_0 to emit any high-level start/finish messages
on rank 0 if needed; update references to the calibrate closure accordingly so
callers can pass is_main=True on rank 0.
- Around line 115-180: The FSDP2 load path in load_and_prepare_fsdp2_model
ignores the --attn_implementation setting, so pass the attn implementation
through to HuggingFace calls and the caller: include args.attn_implementation
(when args is not None) in AutoConfig.from_pretrained,
AutoModelForCausalLM.from_pretrained and AutoModelForCausalLM.from_config by
forwarding it as the attention backend argument expected by HF (e.g.,
attn_implementation or equivalent kwarg), and ensure examples/llm_ptq/hf_ptq.py
passes args.attn_implementation into load_and_prepare_fsdp2_model when invoking
it; update references in this function (hf_config creation, from_pretrained
call, from_config call) and the call site to propagate the flag unchanged.

In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1430-1440: Add a fail-fast validation after argument parsing to
reject the incompatible combination of --use_fsdp2 and auto-quantize: detect
when args.use_fsdp2 is true and the auto-quantize flag/setting (the CLI flag
that triggers mtq.auto_quantize(), e.g., args.auto_quantize_bits or equivalent)
is present/non-zero, log/raise a clear error and exit non-zero before calling
fsdp2_shard() or mtq.auto_quantize(); reference the parser option "--use_fsdp2",
the fsdp2_shard() call and mtq.auto_quantize() so the check runs immediately
after parsing and prevents entering the frozen-parameter path.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 250df2b3-92b6-427a-ad39-e2fb1869f878

📥 Commits

Reviewing files that changed from the base of the PR and between a2c496a and 46cb80e.

📒 Files selected for processing (7)
  • examples/llm_ptq/README.md
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/fsdp2.yaml
  • examples/llm_ptq/hf_ptq.py
  • examples/llm_ptq/multinode_ptq.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/utils/distributed.py
💤 Files with no reviewable changes (2)
  • examples/llm_ptq/fsdp2.yaml
  • examples/llm_ptq/multinode_ptq.py

Comment on lines +68 to +78
if getattr(args, "use_fsdp2", False):
dist_utils.setup()
args.rank = dist_utils.rank()
args.world_size = dist_utils.size()
args.device = torch.device(f"cuda:{dist_utils.local_rank()}")
args.is_main = args.rank == 0
else:
args.rank = 0
args.world_size = 1
args.device = None
args.is_main = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve the requested device when FSDP2 is off.

The else branch overwrites args.device with None, but the normal path still passes args.device into get_model() and related helpers. That silently breaks --device handling for non-FSDP runs, especially --device cpu.

Suggested fix
     else:
         args.rank = 0
         args.world_size = 1
-        args.device = None
         args.is_main = True
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if getattr(args, "use_fsdp2", False):
dist_utils.setup()
args.rank = dist_utils.rank()
args.world_size = dist_utils.size()
args.device = torch.device(f"cuda:{dist_utils.local_rank()}")
args.is_main = args.rank == 0
else:
args.rank = 0
args.world_size = 1
args.device = None
args.is_main = True
if getattr(args, "use_fsdp2", False):
dist_utils.setup()
args.rank = dist_utils.rank()
args.world_size = dist_utils.size()
args.device = torch.device(f"cuda:{dist_utils.local_rank()}")
args.is_main = args.rank == 0
else:
args.rank = 0
args.world_size = 1
args.is_main = True
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 68 - 78, The else branch for
when getattr(args, "use_fsdp2", False) is false is overwriting args.device with
None which breaks non-FSDP device handling (e.g., --device cpu); change the else
branch in that block so it sets args.rank = 0, args.world_size = 1 and
args.is_main = True but does NOT modify args.device (leave existing args.device
intact or only set it if undefined), so downstream calls like get_model()
receive the requested device; keep reference to getattr(args, "use_fsdp2",
False), args.device, args.rank, args.world_size, args.is_main and get_model() to
locate the code to change.

Comment on lines +115 to +180
def load_and_prepare_fsdp2_model(
ckpt_path: str,
device: torch.device,
rank: int,
args=None,
trust_remote_code: bool = False,
mp_policy=None,
):
"""Load and FSDP2-shard a causal LM (accelerate-style rank-0-only CPU load).

Replicates ``accelerate.init_empty_weights(include_buffers=False)`` +
``load_checkpoint_in_model`` manually:

- Rank 0: ``from_pretrained`` on CPU; capture ``src_state_dict``.
- Other ranks: ``from_config`` under ``init_params_on_meta`` → params on
meta (~0 CPU), buffers computed on CPU from config (RoPE inv_freq etc.).
- ``fsdp2_shard`` wraps decoder layers (root stays unsharded), materializes
meta→GPU, broadcasts state_dict from rank 0, re-ties weights, freezes.

Memory: rank 0 holds the full BF16 model in CPU during the broadcast
(~model_size bytes); other ranks pay ~0 CPU. Each rank ends with
``model_size / world_size`` GPU shard storage plus replicated
``embed_tokens`` + ``lm_head`` (~few-GiB total).

v1 supports standard transformers families only (causal LMs that load
cleanly via ``AutoModelForCausalLM``). VILA / pack-quantized /
speculative / VL go through ``get_model`` and don't get FSDP2.
"""
from modelopt.torch.utils.distributed import fsdp2_shard, init_params_on_meta

hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
if args is not None:
validate_fsdp2_supported(args, hf_config)

dtype = getattr(hf_config, "torch_dtype", None) or torch.bfloat16

if rank == 0:
src_model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
torch_dtype="auto",
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
src_model.eval()
src_state_dict = src_model.state_dict()
else:
src_model = None
src_state_dict = {}

with init_params_on_meta():
model = AutoModelForCausalLM.from_config(
hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code
)
model.eval()
if hasattr(model, "config") and hasattr(model.config, "use_cache"):
model.config.use_cache = False

sharded = fsdp2_shard(
model,
device,
rank,
src_state_dict=src_state_dict,
mp_policy=mp_policy,
)
del src_model, src_state_dict
return sharded
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't drop --attn_implementation on the FSDP2 load path.

get_model() threads this CLI option into AutoConfig / from_pretrained, but load_and_prepare_fsdp2_model() ignores it completely. Under --use_fsdp2, the flag becomes a no-op, which can change loading behavior and memory use for models that require a specific attention backend.

Suggested fix
 def load_and_prepare_fsdp2_model(
     ckpt_path: str,
     device: torch.device,
     rank: int,
     args=None,
     trust_remote_code: bool = False,
+    attn_implementation: str | None = None,
     mp_policy=None,
 ):
@@
-    hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
+    config_kwargs = {"trust_remote_code": trust_remote_code}
+    if attn_implementation is not None:
+        config_kwargs["attn_implementation"] = attn_implementation
+    hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
@@
         src_model = AutoModelForCausalLM.from_pretrained(
             ckpt_path,
             torch_dtype="auto",
             trust_remote_code=trust_remote_code,
+            attn_implementation=attn_implementation,
             low_cpu_mem_usage=True,
         )
@@
         model = AutoModelForCausalLM.from_config(
-            hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code
+            hf_config,
+            torch_dtype=dtype,
+            trust_remote_code=trust_remote_code,
+            attn_implementation=attn_implementation,
         )

Please also pass args.attn_implementation from examples/llm_ptq/hf_ptq.py.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 115 - 180, The FSDP2 load
path in load_and_prepare_fsdp2_model ignores the --attn_implementation setting,
so pass the attn implementation through to HuggingFace calls and the caller:
include args.attn_implementation (when args is not None) in
AutoConfig.from_pretrained, AutoModelForCausalLM.from_pretrained and
AutoModelForCausalLM.from_config by forwarding it as the attention backend
argument expected by HF (e.g., attn_implementation or equivalent kwarg), and
ensure examples/llm_ptq/hf_ptq.py passes args.attn_implementation into
load_and_prepare_fsdp2_model when invoking it; update references in this
function (hf_config creation, from_pretrained call, from_config call) and the
call site to propagate the flag unchanged.

Comment thread examples/llm_ptq/example_utils.py Outdated
Comment on lines +1430 to +1440
parser.add_argument(
"--use_fsdp2",
action="store_true",
help=(
"Run calibration under PyTorch FSDP2 (requires launching with torchrun). "
"Takes precedence over --use_seq_device_map. "
"v1 limitations: standard causal-LM only (no VILA / pack-quantized / speculative / "
"auto-quantize / sparsity / VLM). Rank 0 holds the full model in CPU briefly "
"during the broadcast step; other ranks pay ~0 CPU."
),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fail fast on --use_fsdp2 --auto_quantize_bits.

The help text already marks auto-quantize as unsupported for FSDP2, but the parser still accepts it. That combination later reaches mtq.auto_quantize() after fsdp2_shard() has frozen every parameter, so users only discover the problem deep into the run.

Suggested fix
     if args.use_fsdp2 and args.use_seq_device_map:
         warnings.warn("--use_seq_device_map is ignored when --use_fsdp2 is set.")
         args.use_seq_device_map = False
     if args.use_fsdp2 and os.environ.get("RANK") is None:
         parser.error("--use_fsdp2 requires launching with torchrun")
+    if args.use_fsdp2 and args.auto_quantize_bits is not None:
+        parser.error("--use_fsdp2 is not supported with --auto_quantize_bits")

Also applies to: 1545-1549

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/hf_ptq.py` around lines 1430 - 1440, Add a fail-fast
validation after argument parsing to reject the incompatible combination of
--use_fsdp2 and auto-quantize: detect when args.use_fsdp2 is true and the
auto-quantize flag/setting (the CLI flag that triggers mtq.auto_quantize(),
e.g., args.auto_quantize_bits or equivalent) is present/non-zero, log/raise a
clear error and exit non-zero before calling fsdp2_shard() or
mtq.auto_quantize(); reference the parser option "--use_fsdp2", the
fsdp2_shard() call and mtq.auto_quantize() so the check runs immediately after
parsing and prevents entering the frozen-parameter path.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 28, 2026

Codecov Report

❌ Patch coverage is 15.15152% with 84 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.50%. Comparing base (3f88f0d) to head (705316d).
⚠️ Report is 34 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/utils/distributed.py 10.86% 82 Missing ⚠️
modelopt/torch/quantization/utils/core_utils.py 71.42% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1563      +/-   ##
==========================================
- Coverage   76.92%   76.50%   -0.42%     
==========================================
  Files         474      474              
  Lines       51503    51599      +96     
==========================================
- Hits        39618    39476     -142     
- Misses      11885    12123     +238     
Flag Coverage Δ
examples 33.65% <10.10%> (-7.18%) ⬇️
gpu 59.66% <15.15%> (-0.68%) ⬇️
regression 15.20% <10.10%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@sugunav14 sugunav14 marked this pull request as ready for review May 29, 2026 16:47
@sugunav14 sugunav14 requested a review from meenchen May 29, 2026 17:15
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_ptq/example_utils.py (1)

97-112: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject encoder-decoder configs before entering the FSDP2 loader.

load_and_prepare_fsdp2_model() always uses AutoModelForCausalLM, so families like T5/BART/Whisper currently slip through validation and fail later during model construction instead of with the intended early NotImplementedError.

Suggested fix
     if getattr(config, "quantization_config", None) is not None:
         issues.append("pack-quantized / compressed-tensors checkpoints")
+    if getattr(config, "is_encoder_decoder", False):
+        issues.append("encoder-decoder models (FSDP2 v1 only supports causal LMs)")
     if getattr(args, "specdec_offline_dataset", None) is not None:
         issues.append("speculative decoding (--specdec_offline_dataset)")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 97 - 112, The validation
block that collects issues must also reject encoder-decoder configs so
load_and_prepare_fsdp2_model() doesn't call AutoModelForCausalLM on incompatible
families; detect encoder-decoder models by checking config.is_encoder_decoder
(or config.model_type in known encoder-decoder types) and append a descriptive
message to issues (e.g., "encoder-decoder models (T5/BART/Whisper)"), keeping
the existing checks (vila, is_nemotron_vl, _is_multimodal_config,
quantization_config, specdec_offline_dataset, low_memory_mode) and raising the
same NotImplementedError if issues is non-empty before any use of
AutoModelForCausalLM or model construction.
♻️ Duplicate comments (4)
examples/llm_ptq/hf_ptq.py (1)

1552-1558: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject --use_fsdp2 together with --auto_quantize_bits.

load_and_prepare_fsdp2_model() goes through fsdp2_shard(), which freezes every parameter before auto_quantize() runs. The help text already says this combo is unsupported, so error out here instead of failing deep in the quantization flow.

Suggested fix
     if args.use_fsdp2 and os.environ.get("RANK") is None:
         parser.error("--use_fsdp2 requires launching with torchrun")
+    if args.use_fsdp2 and args.auto_quantize_bits is not None:
+        parser.error("--use_fsdp2 is not supported with --auto_quantize_bits")
     if args.cpu_offload and not args.use_fsdp2:
         parser.error("--cpu_offload requires --use_fsdp2")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/hf_ptq.py` around lines 1552 - 1558, The CLI currently
allows combining --use_fsdp2 and --auto_quantize_bits which later breaks because
load_and_prepare_fsdp2_model -> fsdp2_shard freezes parameters before
auto_quantize runs; add an explicit check after parsing args that if
args.use_fsdp2 and args.auto_quantize_bits are both truthy, call parser.error
with a clear message rejecting this combination (mentioning --use_fsdp2 and
--auto_quantize_bits) so the program exits early instead of failing inside
load_and_prepare_fsdp2_model / fsdp2_shard / auto_quantize.
modelopt/torch/utils/distributed.py (1)

439-458: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Hide the calibration progress bar on non-main ranks.

tqdm will render once per rank here, so multi-node calibration still spams overlapping progress bars. Thread an is_main flag into this helper and disable the bar elsewhere.

Suggested fix
-def fsdp_aware_forward_loop(wrapped_model, dataloader, device=None):
+def fsdp_aware_forward_loop(wrapped_model, dataloader, device=None, is_main: bool = True):
@@
-        for batch in tqdm(dataloader, desc="Calibrating"):
+        for batch in tqdm(dataloader, desc="Calibrating", disable=not is_main):

As per coding guidelines **/*.py: Use print_rank_0 or warn_rank_0 when possible to avoid noisy logs and guard shared side effects against race conditions between ranks in distributed processing.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/utils/distributed.py` around lines 439 - 458, The calibration
progress bar in fsdp_aware_forward_loop's inner function calibrate renders on
every rank; add an is_main (or rank0) boolean parameter to
fsdp_aware_forward_loop and thread it into calibrate, then pass disable=not
is_main (or only instantiate tqdm when is_main) so tqdm only shows on the main
rank; additionally wrap any shared-side-effect logs with
print_rank_0/warn_rank_0 and ensure any state changes that could race between
ranks are guarded by the same is_main check.
examples/llm_ptq/example_utils.py (2)

74-78: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't overwrite the requested device in the non-FSDP path.

Setting args.device = None makes --device cpu a no-op and can later feed None into get_model().

Suggested fix
     else:
         args.rank = 0
         args.world_size = 1
-        args.device = None
         args.is_main = True
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 74 - 78, The code sets
args.device = None in the non-FSDP branch which overwrites any user-requested
device and can pass None into get_model(); change this to preserve the provided
device (do not assign None) — either remove the args.device = None assignment or
set args.device = args.device if already defined, ensuring args.device retains
"--device" input before calling get_model() and other functions that expect a
valid device.

115-123: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve --attn_implementation on the FSDP2 load path.

Under --use_fsdp2, this helper ignores the CLI flag entirely, so config/model loading diverges from get_model() and can pick a different attention backend or memory profile.

Suggested fix
 def load_and_prepare_fsdp2_model(
     ckpt_path: str,
     device: torch.device,
     rank: int,
     args=None,
     trust_remote_code: bool = False,
+    attn_implementation: str | None = None,
     mp_policy=None,
     cpu_offload: bool = False,
 ):
@@
-    hf_config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
+    config_kwargs = {"trust_remote_code": trust_remote_code}
+    if attn_implementation is not None:
+        config_kwargs["attn_implementation"] = attn_implementation
+    hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
@@
         src_model = AutoModelForCausalLM.from_pretrained(
             ckpt_path,
             torch_dtype="auto",
             trust_remote_code=trust_remote_code,
+            attn_implementation=attn_implementation,
             low_cpu_mem_usage=True,
         )
@@
         model = AutoModelForCausalLM.from_config(
-            hf_config, torch_dtype=dtype, trust_remote_code=trust_remote_code
+            hf_config,
+            torch_dtype=dtype,
+            trust_remote_code=trust_remote_code,
+            attn_implementation=attn_implementation,
         )

Please also pass args.attn_implementation from examples/llm_ptq/hf_ptq.py.

Also applies to: 146-168

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 115 - 123, The FSDP2 load
path in load_and_prepare_fsdp2_model currently ignores the CLI flag for
attention backend; update the function to accept and propagate
args.attn_implementation (or a dedicated attn_implementation param) when
building the model/config so the same attention implementation used by
get_model() is preserved, and ensure callers (examples/llm_ptq/hf_ptq.py) pass
args.attn_implementation into load_and_prepare_fsdp2_model; also mirror this
change in the related FSDP2 code paths around the block at 146-168 so the
attention backend and memory profile remain consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 97-112: The validation block that collects issues must also reject
encoder-decoder configs so load_and_prepare_fsdp2_model() doesn't call
AutoModelForCausalLM on incompatible families; detect encoder-decoder models by
checking config.is_encoder_decoder (or config.model_type in known
encoder-decoder types) and append a descriptive message to issues (e.g.,
"encoder-decoder models (T5/BART/Whisper)"), keeping the existing checks (vila,
is_nemotron_vl, _is_multimodal_config, quantization_config,
specdec_offline_dataset, low_memory_mode) and raising the same
NotImplementedError if issues is non-empty before any use of
AutoModelForCausalLM or model construction.

---

Duplicate comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 74-78: The code sets args.device = None in the non-FSDP branch
which overwrites any user-requested device and can pass None into get_model();
change this to preserve the provided device (do not assign None) — either remove
the args.device = None assignment or set args.device = args.device if already
defined, ensuring args.device retains "--device" input before calling
get_model() and other functions that expect a valid device.
- Around line 115-123: The FSDP2 load path in load_and_prepare_fsdp2_model
currently ignores the CLI flag for attention backend; update the function to
accept and propagate args.attn_implementation (or a dedicated
attn_implementation param) when building the model/config so the same attention
implementation used by get_model() is preserved, and ensure callers
(examples/llm_ptq/hf_ptq.py) pass args.attn_implementation into
load_and_prepare_fsdp2_model; also mirror this change in the related FSDP2 code
paths around the block at 146-168 so the attention backend and memory profile
remain consistent.

In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1552-1558: The CLI currently allows combining --use_fsdp2 and
--auto_quantize_bits which later breaks because load_and_prepare_fsdp2_model ->
fsdp2_shard freezes parameters before auto_quantize runs; add an explicit check
after parsing args that if args.use_fsdp2 and args.auto_quantize_bits are both
truthy, call parser.error with a clear message rejecting this combination
(mentioning --use_fsdp2 and --auto_quantize_bits) so the program exits early
instead of failing inside load_and_prepare_fsdp2_model / fsdp2_shard /
auto_quantize.

In `@modelopt/torch/utils/distributed.py`:
- Around line 439-458: The calibration progress bar in fsdp_aware_forward_loop's
inner function calibrate renders on every rank; add an is_main (or rank0)
boolean parameter to fsdp_aware_forward_loop and thread it into calibrate, then
pass disable=not is_main (or only instantiate tqdm when is_main) so tqdm only
shows on the main rank; additionally wrap any shared-side-effect logs with
print_rank_0/warn_rank_0 and ensure any state changes that could race between
ranks are guarded by the same is_main check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0c53a76c-6ab3-44c4-92db-4acf83822755

📥 Commits

Reviewing files that changed from the base of the PR and between 46cb80e and 705316d.

📒 Files selected for processing (3)
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/utils/distributed.py

@sugunav14 sugunav14 requested a review from shengliangxu May 29, 2026 20:07
@sugunav14 sugunav14 marked this pull request as draft May 29, 2026 20:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant