From d0efb620deec4634d5f3d963f32e7d0704e33914 Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:11:37 +0200 Subject: [PATCH 1/6] Fix: swap input_amplitude and target_amplitude in JukeboxLoss.forward Fixes #8820 - input_amplitude was incorrectly computed from `target` and target_amplitude from `input`. Corrected to match semantic meaning and standard forward(input, target) convention. Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- monai/losses/spectral_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py index 06714f3993..fcba03f132 100644 --- a/monai/losses/spectral_loss.py +++ b/monai/losses/spectral_loss.py @@ -55,8 +55,8 @@ def __init__( self.fft_norm = fft_norm def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - input_amplitude = self._get_fft_amplitude(target) - target_amplitude = self._get_fft_amplitude(input) + input_amplitude = self._get_fft_amplitude(input) + target_amplitude = self._get_fft_amplitude(target) # Compute distance between amplitude of frequency components # See Section 3.3 from https://arxiv.org/abs/2005.00341 From 4417d79b54de36f2677fc32e70cfa246a104f097 Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Sat, 11 Apr 2026 20:39:24 +0200 Subject: [PATCH 2/6] Fix: remove redundant `1-` from SSIMLoss docstring examples Fixes #8822 - The forward() docstring examples used `print(1-SSIMLoss()(x,y))`, but SSIMLoss already computes 1-ssim internally. The `1-` prefix made examples return ssim (not loss), misleading users into training with inverted loss. Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- monai/losses/ssim_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ee1da7267..3fa578da29 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D data x = torch.ones([1,1,10,10])/2 y = torch.ones([1,1,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # pseudo-3D data x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices y = torch.ones([1,5,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # 3D data x = torch.ones([1,1,10,10,10])/2 y = torch.ones([1,1,10,10,10])/2 - print(1-SSIMLoss(spatial_dims=3)(x,y)) + print(SSIMLoss(spatial_dims=3)(x,y)) """ ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1) loss: torch.Tensor = 1 - ssim_value From c4dd5c3f96f9b3320f44c5f09725cd67d1bcd58b Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:42:38 +0200 Subject: [PATCH 3/6] Add inspect_ckpt CLI command to monai.bundle Adds inspect_ckpt function to analImplements the `python -m monai.bundle inspect_ckpt` CLI command requested in issue #5537. The new `inspect_ckpt` function in `monai/bundle/scripts.py`: - Loads a checkpoint file and displays tensor names, shapes, and dtypes - Optionally computes the file hash (md5 or sha256), useful for creating large_files.yml entries in model-zoo bundles - Follows the same patterns as other bundle CLI commands Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com>yze checkpoint files. Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- monai/bundle/scripts.py | 73 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fa9ba27096..e74bfad312 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -2013,3 +2013,76 @@ def download_large_files(bundle_path: str | None = None, large_file_name: str | lf_data["filepath"] = os.path.join(bundle_path, lf_data["path"]) lf_data.pop("path") download_url(**lf_data) + +def inspect_ckpt( + path: str, + print_all_vars: bool = True, + compute_hash: bool = False, + hash_type: str = "md5", +) -> dict: + """ + Inspect the variables and shapes saved in a checkpoint file. + Prints a human-readable summary of the tensor names, shapes, and dtypes + stored in the checkpoint, similar to TensorFlow's inspect_checkpoint tool. + Optionally also computes the hash value of the file (useful when creating + a ``large_files.yml`` for model-zoo bundles). + + Typical usage examples: + + .. code-block:: bash + + # Display all tensor names, shapes, and dtypes: + python -m monai.bundle inspect_ckpt --path model.pt + + # Suppress individual variable printing (only show file-level info): + python -m monai.bundle inspect_ckpt --path model.pt --print_all_vars false + + # Also compute md5 hash of the checkpoint file: + python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true + + # Use sha256 hash instead of md5: + python -m monai.bundle inspect_ckpt --path model.pt --compute_hash true --hash_type sha256 + + Args: + path: path to the checkpoint file to inspect. + print_all_vars: whether to print individual variable names, shapes, + and dtypes. Default to ``True``. + compute_hash: whether to compute and print the hash value of the + checkpoint file. Default to ``False``. + hash_type: the hash type to use when ``compute_hash`` is ``True``. + Should be ``"md5"`` or ``"sha256"``. Default to ``"md5"``. + + Returns: + A dictionary mapping variable names to a dict containing + ``"shape"`` (tuple) and ``"dtype"`` (str) for each tensor. + """ + import hashlib + + _log_input_summary(tag="inspect_ckpt", args={"path": path, "print_all_vars": print_all_vars, "compute_hash": compute_hash}) + + ckpt = torch.load(path, map_location="cpu", weights_only=True) + if not isinstance(ckpt, Mapping): + ckpt = get_state_dict(ckpt) + + var_info: dict = {} + for name, val in ckpt.items(): + if isinstance(val, torch.Tensor): + var_info[name] = {"shape": tuple(val.shape), "dtype": str(val.dtype)} + else: + var_info[name] = {"shape": None, "dtype": type(val).__name__} + + logger.info(f"checkpoint file: {path}") + logger.info(f"total variables: {len(var_info)}") + if print_all_vars: + for name, info in var_info.items(): + logger.info(f" {name}: shape={info['shape']}, dtype={info['dtype']}") + + if compute_hash: + h = hashlib.new(hash_type) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + digest = h.hexdigest() + logger.info(f"{hash_type} hash: {digest}") + + return var_info From 1a8d8f1c5a15ee3599b1f0a0213440300716aecb Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:43:45 +0200 Subject: [PATCH 4/6] Export inspect_ckpt from monai.bundle Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- monai/bundle/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 3f3c8d545e..065d55089c 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -24,6 +24,7 @@ get_bundle_info, get_bundle_versions, init_bundle, + inspect_ckpt, load, onnx_export, push_to_hf_hub, From 36eed9d752e9cccd1b74749523b9e8047888fde8 Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:44:53 +0200 Subject: [PATCH 5/6] Add inspect_ckpt to bundle CLI entry point Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- monai/bundle/__main__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 778c9ef2f0..edce1567df 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -16,6 +16,7 @@ download, download_large_files, init_bundle, + inspect_ckpt, onnx_export, run, run_workflow, From c94c5596639cff3f07ad302d189a658243304e56 Mon Sep 17 00:00:00 2001 From: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:45:55 +0200 Subject: [PATCH 6/6] Add unit tests for inspect_ckpt Signed-off-by: Zeeshan Modi <92383127+Zeesejo@users.noreply.github.com> --- tests/bundle/test_bundle_inspect_ckpt.py | 70 ++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/bundle/test_bundle_inspect_ckpt.py diff --git a/tests/bundle/test_bundle_inspect_ckpt.py b/tests/bundle/test_bundle_inspect_ckpt.py new file mode 100644 index 0000000000..ab569b234d --- /dev/null +++ b/tests/bundle/test_bundle_inspect_ckpt.py @@ -0,0 +1,70 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import tempfile +import unittest + +import torch + +from monai.bundle import inspect_ckpt + + +class TestInspectCkpt(unittest.TestCase): + def setUp(self): + # Create a temporary checkpoint file with a simple state dict + self.tmp_dir = tempfile.mkdtemp() + self.ckpt_path = os.path.join(self.tmp_dir, "model.pt") + state_dict = { + "layer1.weight": torch.randn(4, 3), + "layer1.bias": torch.zeros(4), + "layer2.weight": torch.randn(2, 4), + } + torch.save(state_dict, self.ckpt_path) + + def test_returns_dict_with_correct_keys(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertIsInstance(result, dict) + self.assertIn("layer1.weight", result) + self.assertIn("layer1.bias", result) + self.assertIn("layer2.weight", result) + + def test_shapes_are_correct(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertEqual(result["layer1.weight"]["shape"], (4, 3)) + self.assertEqual(result["layer1.bias"]["shape"], (4,)) + self.assertEqual(result["layer2.weight"]["shape"], (2, 4)) + + def test_dtype_is_reported(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False) + self.assertIn("dtype", result["layer1.weight"]) + self.assertTrue(result["layer1.weight"]["dtype"].startswith("torch.")) + + def test_compute_hash_md5(self): + # Should not raise; hash value is logged but not returned in dict + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="md5") + self.assertIsInstance(result, dict) + + def test_compute_hash_sha256(self): + result = inspect_ckpt(path=self.ckpt_path, print_all_vars=False, compute_hash=True, hash_type="sha256") + self.assertIsInstance(result, dict) + + def test_print_all_vars_true_does_not_raise(self): + # Should log each variable without raising + try: + inspect_ckpt(path=self.ckpt_path, print_all_vars=True) + except Exception as e: + self.fail(f"inspect_ckpt raised an exception with print_all_vars=True: {e}") + + +if __name__ == "__main__": + unittest.main()