From de00881ab20289edb952b16f1b5ed73377ba1d05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Thu, 23 Apr 2026 13:56:22 +0200 Subject: [PATCH 01/24] Implement segmentation refinement as in gMRI2FEM and csf segmentation, to be used in intracranial mask computation --- src/mritk/cli.py | 9 ++- src/mritk/masks.py | 63 ++++++++++++++--- src/mritk/segmentation.py | 137 ++++++++++++++++++++++++++++++++++++- tests/create_test_data.py | 3 + tests/test_masks.py | 34 +++++++-- tests/test_segmentation.py | 43 ++++++++++++ 6 files changed, 271 insertions(+), 18 deletions(-) diff --git a/src/mritk/cli.py b/src/mritk/cli.py index 10fe43a..5e4615c 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -9,7 +9,7 @@ from rich.logging import RichHandler from rich_argparse import RichHelpFormatter -from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, show, statistics +from . import concentration, datasets, hybrid, info, segmentation, looklocker, masks, mixed, napari, r1, show, statistics def version_info(): @@ -75,6 +75,11 @@ def setup_parser(): napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class) napari.add_arguments(napari_parser) + segmentation_parser = subparsers.add_parser( + "seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class + ) + segmentation.add_arguments(segmentation_parser, extra_args_cb=add_extra_arguments) + looklocker_parser = subparsers.add_parser( "looklocker", help="Process Look-Locker data", formatter_class=parser.formatter_class ) @@ -142,6 +147,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No show.dispatch(args) elif command == "napari": napari.dispatch(args) + elif command == "seg": + segmentation.dispatch(args) elif command == "looklocker": looklocker.dispatch(args) elif command == "mask": diff --git a/src/mritk/masks.py b/src/mritk/masks.py index a50b4c6..05c8573 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -10,6 +10,7 @@ import numpy as np import skimage +import scipy.interpolate from .data import MRIData from .testing import assert_same_space @@ -107,6 +108,45 @@ def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> return mri_data +def csf_segmentation(input_segmentation: Path | MRIData, csf_mask: Path | MRIData) -> MRIData: + """ + Generates a CSF segmentation by applying a CSF mask to an anatomical segmentation. + + This function takes an anatomical segmentation (e.g., from FreeSurfer) and a CSF mask, + and produces a new segmentation where voxels identified as CSF in the mask are labeled + with their original segmentation values, while non-CSF voxels are set to zero. + + Args: + input_segmentation (Path | MRIData): Path to the anatomical segmentation NIfTI file + or an MRIData object containing the resampled segmentation. + csf_mask (Path | MRIData): Either a path to a CSF mask NIfTI file or an MRIData object containing the mask. + + Returns: + MRIData: An MRIData object containing the CSF segmentation. + """ + if isinstance(input_segmentation, Path): + seg_mri = MRIData.from_file(input_segmentation, dtype=np.int16) + else: + seg_mri = input_segmentation + + if isinstance(csf_mask, Path): + csf_mask_mri = MRIData.from_file(csf_mask, dtype=bool) + else: + csf_mask_mri = csf_mask + + assert_same_space(seg_mri, csf_mask_mri) + + # Get interpolation operator + I, J, K = np.where(seg_mri.data != 0) + interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, seg_mri.data[I, J, K]) + # Interpolate segmentation values at CSF mask locations + i, j, k = np.where(csf_mask_mri.data != 0) + csf_seg = np.zeros_like(seg_mri.data, dtype=np.int16) + csf_seg[i, j, k] = interp(i, j, k) + + return MRIData(data=csf_seg.astype(np.int16), affine=csf_mask_mri.affine) + + def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_array: np.ndarray) -> np.ndarray: """ Combines a CSF mask array and a brain segmentation mask array into a solid intracranial mask. @@ -134,7 +174,7 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr return ~opened_background -def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> MRIData: +def intracranial_mask(segmentation_path: Path, csf_mask_path: Path) -> MRIData: """ I/O wrapper for generating and saving an intracranial mask from NIfTI files. @@ -142,20 +182,21 @@ def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> M delegates the array computation. Args: - csf_segmentation_path (Path): Path to the CSF segmentation NIfTI file. - segmentation_path (Path): Path to the brain segmentation NIfTI file. - output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. + segmentation_path (Path): Path to the brain (refined) segmentation NIfTI file, \ + generated by the segmentation refinement module. + csf_mask_path (Path): Path to the CSF mask, generated by the csf mask module. Returns: MRIData: An MRIData object containing the intracranial mask. """ - input_csf_mask = MRIData.from_file(csf_segmentation_path, dtype=bool) + # Get segmentation data and csf segmentation segmentation_data = MRIData.from_file(segmentation_path, dtype=bool) + csf_seg = csf_segmentation(input_segmentation=segmentation_data, csf_mask=csf_mask_path) # Validate spatial alignment before array operations - assert_same_space(input_csf_mask, segmentation_data) + assert_same_space(csf_seg, segmentation_data) - mask_data = compute_intracranial_mask_array(input_csf_mask.data, segmentation_data.data) + mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation_data.data) mri_data = MRIData(data=mask_data, affine=segmentation_data.affine) return mri_data @@ -183,8 +224,10 @@ def add_arguments( intracranial_mask_parser = subparser.add_parser( "intracranial", help="Compute intracranial mask", formatter_class=parser.formatter_class ) - intracranial_mask_parser.add_argument("--csf-segmentation-path", type=Path, help="Path to the CSF segmentation NIfTI file") - intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to the brain segmentation NIfTI file") + intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to refined segmentation file, generated by \ + the segmentation refinement module, i.e. mritk seg refine") + intracranial_mask_parser.add_argument("--csf-mask-path", type=Path, help="Path to the CSF mask NIfTI file, generated by \ + the csf mask module, i.e. mritk mask csf") intracranial_mask_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resulting mask") if extra_args_cb is not None: @@ -199,7 +242,7 @@ def dispatch(args): csf_mask_data.save(args.pop("output"), dtype=np.uint8) elif command == "intracranial": intracranial_mask_data = intracranial_mask( - csf_segmentation_path=args.pop("csf_segmentation_path"), segmentation_path=args.pop("segmentation_path") + segmentation_path=args.pop("segmentation_path"), csf_mask_path=args.pop("csf_mask_path") ) intracranial_mask_data.save(args.pop("output"), dtype=np.uint8) else: diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 0f97fe5..252e4a7 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -9,12 +9,16 @@ import re from pathlib import Path from urllib.request import urlretrieve +import itertools +import scipy +import argparse +from collections.abc import Callable import numpy as np import numpy.typing as npt import pandas as pd -from .data import MRIData, load_mri_data +from .data import MRIData, load_mri_data, apply_affine logger = logging.getLogger(__name__) @@ -88,6 +92,10 @@ class Segmentation(MRIData): labels to a descriptive Lookup Table (LUT). """ + mri: MRIData + rois: np.ndarray + lut: pd.DataFrame + def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None): """ Initializes the Segmentation object. @@ -147,6 +155,73 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index() + def resample_to_reference(self, reference_mri: MRIData): + """ + Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI. + + Args: + reference_mri (MRIData): The MRI to which the segmentation will be resampled, + for example a T1-weighted anatomical scan. + Returns: + Segmentation: A new Segmentation object containing the resampled data. + """ + + shape_in = self.shape + shape_out = reference_mri.shape + + # Generate a grid of voxel indices for the output space + upsampled_indices = np.fromiter( + itertools.product(*(np.arange(ni) for ni in shape_out)), + dtype=np.dtype((int, 3)), + ) + # Get voxel indices in the input segmentation space corresponding to the output grid + seg_indices = apply_affine( + np.linalg.inv(self.affine), + apply_affine(reference_mri.affine, upsampled_indices), + ) + seg_indices = np.rint(seg_indices).astype(int) + + # The two images does not necessarily share field of view. + # Remove voxels which are not located within the segmentation fov. + valid_index_mask = (seg_indices > 0).all(axis=1) * (seg_indices < shape_in).all(axis=1) + upsampled_indices = upsampled_indices[valid_index_mask] + seg_indices = seg_indices[valid_index_mask] + + seg_upsampled = np.zeros(shape_out, dtype=self.data.dtype) + I_in, J_in, K_in = seg_indices.T + I_out, J_out, K_out = upsampled_indices.T + seg_upsampled[I_out, J_out, K_out] = self.data[I_in, J_in, K_in] + + return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + + def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: + """ + Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map. + + Args: + sigma (float): The standard deviation for the Gaussian kernel. + cutoff_score (float, optional): A threshold to remove low-confidence voxels. Defaults to 0.5. + **kwargs: Additional keyword arguments passed to scipy.ndimage.gaussian_filter. + + Returns: + dict[str, np.ndarray]: A dictionary containing 'labels' (the smoothed segmentation) + and 'scores' (the confidence scores for each voxel). + """ + smoothed_rois = np.zeros_like(self.data) + high_scores = np.zeros(self.data.shape) + + for roi in self.rois: + scores = scipy.ndimage.gaussian_filter( + (self.data == roi).astype(float), sigma=sigma, **kwargs + ) + is_new_high_score = scores > high_scores + smoothed_rois[is_new_high_score] = roi + high_scores[is_new_high_score] = scores[is_new_high_score] + + delete_scores = (high_scores < cutoff_score) * (self.data == 0) + smoothed_rois[delete_scores] = 0 + + return MRIData(data=smoothed_rois, affine=self.affine) class FreeSurferSegmentation(Segmentation): """ @@ -407,3 +482,63 @@ def write_lut(filename: Path, table: pd.DataFrame): # Save as tab-separated values without headers or indices newtable.to_csv(filename, sep="\t", index=False, header=False) + + + +def add_arguments( + parser: argparse.ArgumentParser, + extra_args_cb: Callable[[argparse.ArgumentParser], None] | None = None, +) -> None: + subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing") + + resample_parser = subparser.add_parser( + "resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class + ) + resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + resample_parser.add_argument("-r", "--reference", type=Path, help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan") + resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation") + + smooth_parser = subparser.add_parser( + "smooth", help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", formatter_class=parser.formatter_class + ) + smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file") + smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + smooth_parser.add_argument("-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)") + smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation") + + + refine_parser = subparser.add_parser( + "refine", help="Refine a segmentation by applying Gaussian smoothing to the labels", formatter_class=parser.formatter_class + ) + refine_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + refine_parser.add_argument("-r", "--reference", type=Path, help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan") + refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation") + + if extra_args_cb is not None: + extra_args_cb(resample_parser) + extra_args_cb(smooth_parser) + extra_args_cb(refine_parser) + + +def dispatch(args): + command = args.pop("seg-command") + if command == "resample": + print("Resampling segmentation...") + input_seg = Segmentation.from_file(args.pop("input")) + reference_mri = Segmentation.from_file(args.pop("reference")) + resampled_seg = input_seg.resample_to_reference(reference_mri) + resampled_seg.save(args.pop("output"), dtype=np.int32) + elif command == "smooth": + smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff")) + smoothed.save(args.pop("output"), dtype=np.int32) + elif command == "refine": + seg = Segmentation.from_file(args.pop("input")) + refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) + smoothed = refined.smooth(sigma=args.pop("smooth")) + refined.data = np.where(smoothed.data > 0, smoothed.data, refined.data) + refined.save(args.pop("output"), dtype=np.int32) + else: + raise ValueError(f"Unknown segmentation command: {command}") diff --git a/tests/create_test_data.py b/tests/create_test_data.py index 811a335..b55dae0 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -26,6 +26,9 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] for file in files: diff --git a/tests/test_masks.py b/tests/test_masks.py index 2494e43..f9a92bb 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -7,12 +7,20 @@ from pathlib import Path from unittest.mock import patch +import pytest import nibabel as nib import numpy as np import mritk.cli -from mritk.masks import compute_csf_mask_array, compute_intracranial_mask_array, csf_mask, intracranial_mask, largest_island +from mritk.masks import ( + compute_csf_mask_array, + compute_intracranial_mask_array, + csf_mask, + intracranial_mask, + csf_segmentation, + largest_island +) from mritk.testing import compare_nifti_images @@ -146,7 +154,7 @@ def test_intracranial_mask_io(tmp_path): seg_data[4:6, 4:6, 4:6] = 1.0 nib.save(nib.Nifti1Image(seg_data, affine), seg_path) - result = intracranial_mask(csf_segmentation_path=csf_path, segmentation_path=seg_path) + result = intracranial_mask(segmentation_path=seg_path, csf_mask_path=csf_path) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -170,17 +178,17 @@ def test_dispatch_intracranial_mask(mock_intracranial_mask): [ "mask", "intracranial", - "--csf-segmentation-path", - "csf_segmentation.nii.gz", "--segmentation-path", "segmentation.nii.gz", + "--csf-mask-path", + "csf_mask.nii.gz", "-o", "ic_mask.nii.gz", ] ) mock_intracranial_mask.assert_called_once_with( - csf_segmentation_path=Path("csf_segmentation.nii.gz"), segmentation_path=Path("segmentation.nii.gz") + segmentation_path=Path("segmentation.nii.gz"), csf_mask_path=Path("csf_mask.nii.gz") ) @@ -198,11 +206,25 @@ def test_csf_mask(tmp_path, mri_data_dir: Path): def test_intracranial_mask(tmp_path, mri_data_dir: Path): csf_segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz" + csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz" ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz" test_output = tmp_path / "output_seg-intracranial_binary.nii.gz" - result = intracranial_mask(csf_segmentation_path=csf_segmentation_path, segmentation_path=segmentation_path) + result = intracranial_mask(segmentation_path=segmentation_path, csf_mask_path=csf_mask_path) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): + """Test the CSF segmentation logic by comparing against a known reference.""" + input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz" + input_csf_mask = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" + + ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz" + + result = csf_segmentation(input_segmentation=input_T2w_path, csf_mask=input_csf_mask) + result.save(test_output, dtype=np.uint8) + compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) \ No newline at end of file diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index a368fde..67618b9 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -17,6 +17,10 @@ validate_lut_file, write_lut, ) +from mritk.data import MRIData +from mritk.testing import compare_nifti_images +import mritk.cli + def test_segmentation_initialization(example_segmentation: Segmentation): @@ -180,3 +184,42 @@ def test_write_lut_file_io(tmp_path): # Verify the denormalization restored the original 0-255 integers assert content[0] == "4\tLeft-Lateral-Ventricle\t120\t18\t134\t0" assert content[1] == "5\tLeft-Inf-Lat-Vent\t198\t51\t122\t0" + +# Note : Refinement is actually testing both resampling and smoothing +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): + + ref_mri = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" + smoothing = 1 + + FS_segmentation = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" + ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + test_output = tmp_path / "output_refined.nii.gz" + + fs_input = Segmentation.from_file(FS_segmentation) + result = fs_input.resample_to_reference(MRIData.from_file(ref_mri)) + smoothed = result.smooth(sigma=smoothing) + result.data = smoothed.data + result.save(test_output, dtype=np.int32) + compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + +@patch("mritk.segmentation.Segmentation") +def test_dispatch_resample(mock_seg): + """Test that dispatch correctly routes to segmentation resample.""" + + mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) + + assert mock_seg.from_file.call_count == 2 # Called once for input segmentation and once for reference MRI + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.resample_to_reference.assert_called_once_with(mock_seg.from_file.return_value) + +@patch("mritk.segmentation.Segmentation") +def test_dispatch_smoothing(mock_seg): + """Test that dispatch correctly routes to segmentation smoothing.""" + + mritk.cli.main(["seg", "smooth", "-i", "mock_in.nii.gz", "-o", "mock_out.nii.gz", "-s", "1"]) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) From ab5ee6093d128f83157a54f381a8ccd44205c9e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:07:04 +0000 Subject: [PATCH 02/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mritk/cli.py | 6 ++--- src/mritk/masks.py | 20 +++++++++++----- src/mritk/segmentation.py | 47 ++++++++++++++++++++++++-------------- tests/create_test_data.py | 2 +- tests/test_masks.py | 9 ++++---- tests/test_segmentation.py | 16 +++++++------ 6 files changed, 61 insertions(+), 39 deletions(-) diff --git a/src/mritk/cli.py b/src/mritk/cli.py index 5e4615c..f657c50 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -9,7 +9,7 @@ from rich.logging import RichHandler from rich_argparse import RichHelpFormatter -from . import concentration, datasets, hybrid, info, segmentation, looklocker, masks, mixed, napari, r1, show, statistics +from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, segmentation, show, statistics def version_info(): @@ -75,9 +75,7 @@ def setup_parser(): napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class) napari.add_arguments(napari_parser) - segmentation_parser = subparsers.add_parser( - "seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class - ) + segmentation_parser = subparsers.add_parser("seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class) segmentation.add_arguments(segmentation_parser, extra_args_cb=add_extra_arguments) looklocker_parser = subparsers.add_parser( diff --git a/src/mritk/masks.py b/src/mritk/masks.py index 05c8573..6343782 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -9,8 +9,8 @@ from pathlib import Path import numpy as np -import skimage import scipy.interpolate +import skimage from .data import MRIData from .testing import assert_same_space @@ -117,7 +117,7 @@ def csf_segmentation(input_segmentation: Path | MRIData, csf_mask: Path | MRIDat with their original segmentation values, while non-CSF voxels are set to zero. Args: - input_segmentation (Path | MRIData): Path to the anatomical segmentation NIfTI file + input_segmentation (Path | MRIData): Path to the anatomical segmentation NIfTI file or an MRIData object containing the resampled segmentation. csf_mask (Path | MRIData): Either a path to a CSF mask NIfTI file or an MRIData object containing the mask. @@ -224,10 +224,18 @@ def add_arguments( intracranial_mask_parser = subparser.add_parser( "intracranial", help="Compute intracranial mask", formatter_class=parser.formatter_class ) - intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to refined segmentation file, generated by \ - the segmentation refinement module, i.e. mritk seg refine") - intracranial_mask_parser.add_argument("--csf-mask-path", type=Path, help="Path to the CSF mask NIfTI file, generated by \ - the csf mask module, i.e. mritk mask csf") + intracranial_mask_parser.add_argument( + "--segmentation-path", + type=Path, + help="Path to refined segmentation file, generated by \ + the segmentation refinement module, i.e. mritk seg refine", + ) + intracranial_mask_parser.add_argument( + "--csf-mask-path", + type=Path, + help="Path to the CSF mask NIfTI file, generated by \ + the csf mask module, i.e. mritk mask csf", + ) intracranial_mask_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resulting mask") if extra_args_cb is not None: diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 252e4a7..7d5c457 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -4,21 +4,21 @@ # Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no) # Copyright (C) 2026 Simula Research Laboratory +import argparse +import itertools import logging import os import re +from collections.abc import Callable from pathlib import Path from urllib.request import urlretrieve -import itertools -import scipy -import argparse -from collections.abc import Callable import numpy as np import numpy.typing as npt import pandas as pd +import scipy -from .data import MRIData, load_mri_data, apply_affine +from .data import MRIData, apply_affine, load_mri_data logger = logging.getLogger(__name__) @@ -211,9 +211,7 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: high_scores = np.zeros(self.data.shape) for roi in self.rois: - scores = scipy.ndimage.gaussian_filter( - (self.data == roi).astype(float), sigma=sigma, **kwargs - ) + scores = scipy.ndimage.gaussian_filter((self.data == roi).astype(float), sigma=sigma, **kwargs) is_new_high_score = scores > high_scores smoothed_rois[is_new_high_score] = roi high_scores[is_new_high_score] = scores[is_new_high_score] @@ -223,6 +221,7 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: return MRIData(data=smoothed_rois, affine=self.affine) + class FreeSurferSegmentation(Segmentation): """ Segmentation class specifically tailored for FreeSurfer outputs. @@ -484,7 +483,6 @@ def write_lut(filename: Path, table: pd.DataFrame): newtable.to_csv(filename, sep="\t", index=False, header=False) - def add_arguments( parser: argparse.ArgumentParser, extra_args_cb: Callable[[argparse.ArgumentParser], None] | None = None, @@ -495,25 +493,40 @@ def add_arguments( "resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class ) resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") - resample_parser.add_argument("-r", "--reference", type=Path, help="Path to the reference MRI \ - - usually a registered T1 weighted anatomical scan") + resample_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation") smooth_parser = subparser.add_parser( - "smooth", help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", formatter_class=parser.formatter_class + "smooth", + help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", + formatter_class=parser.formatter_class, ) smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file") smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") - smooth_parser.add_argument("-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)") + smooth_parser.add_argument( + "-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)" + ) smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation") - refine_parser = subparser.add_parser( - "refine", help="Refine a segmentation by applying Gaussian smoothing to the labels", formatter_class=parser.formatter_class + "refine", + help="Refine a segmentation by applying Gaussian smoothing to the labels", + formatter_class=parser.formatter_class, ) refine_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") - refine_parser.add_argument("-r", "--reference", type=Path, help="Path to the reference MRI \ - - usually a registered T1 weighted anatomical scan") + refine_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation") diff --git a/tests/create_test_data.py b/tests/create_test_data.py index b55dae0..222fba6 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -28,7 +28,7 @@ def main(): "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", - "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] for file in files: diff --git a/tests/test_masks.py b/tests/test_masks.py index f9a92bb..a6a26b1 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -7,19 +7,19 @@ from pathlib import Path from unittest.mock import patch -import pytest import nibabel as nib import numpy as np +import pytest import mritk.cli from mritk.masks import ( compute_csf_mask_array, compute_intracranial_mask_array, csf_mask, - intracranial_mask, csf_segmentation, - largest_island + intracranial_mask, + largest_island, ) from mritk.testing import compare_nifti_images @@ -216,6 +216,7 @@ def test_intracranial_mask(tmp_path, mri_data_dir: Path): result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): """Test the CSF segmentation logic by comparing against a known reference.""" @@ -227,4 +228,4 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): result = csf_segmentation(input_segmentation=input_T2w_path, csf_mask=input_csf_mask) result.save(test_output, dtype=np.uint8) - compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) \ No newline at end of file + compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 67618b9..30592b5 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -5,6 +5,8 @@ import pandas as pd import pytest +import mritk.cli +from mritk.data import MRIData from mritk.segmentation import ( LUT_REGEX, VENTRICLES, @@ -17,10 +19,7 @@ validate_lut_file, write_lut, ) -from mritk.data import MRIData from mritk.testing import compare_nifti_images -import mritk.cli - def test_segmentation_initialization(example_segmentation: Segmentation): @@ -185,6 +184,7 @@ def test_write_lut_file_io(tmp_path): assert content[0] == "4\tLeft-Lateral-Ventricle\t120\t18\t134\t0" assert content[1] == "5\tLeft-Inf-Lat-Vent\t198\t51\t122\t0" + # Note : Refinement is actually testing both resampling and smoothing @pytest.mark.skip(reason="Takes too long to run") @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) @@ -193,7 +193,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): ref_mri = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" smoothing = 1 - FS_segmentation = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" + FS_segmentation = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" test_output = tmp_path / "output_refined.nii.gz" @@ -204,16 +204,18 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): result.save(test_output, dtype=np.int32) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + @patch("mritk.segmentation.Segmentation") def test_dispatch_resample(mock_seg): """Test that dispatch correctly routes to segmentation resample.""" mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) - assert mock_seg.from_file.call_count == 2 # Called once for input segmentation and once for reference MRI - inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + assert mock_seg.from_file.call_count == 2 # Called once for input segmentation and once for reference MRI + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_seg.from_file.return_value) + @patch("mritk.segmentation.Segmentation") def test_dispatch_smoothing(mock_seg): """Test that dispatch correctly routes to segmentation smoothing.""" @@ -221,5 +223,5 @@ def test_dispatch_smoothing(mock_seg): mritk.cli.main(["seg", "smooth", "-i", "mock_in.nii.gz", "-o", "mock_out.nii.gz", "-s", "1"]) mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) - inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) From 1479097bad0d810cf8747d368c90eb445885d90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Thu, 23 Apr 2026 14:42:23 +0200 Subject: [PATCH 03/24] Add missing files in test_data --- tests/create_test_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/create_test_data.py b/tests/create_test_data.py index b55dae0..5087c80 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -21,9 +21,12 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-looklocker_T1map.nii.gz", - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_acq-looklocker_T1map_registered.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aparc+aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-wmparc.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aseg_refined.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", @@ -31,6 +34,7 @@ def main(): "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] + for file in files: src = inputdir / file dst = outdir / file From e0b062b846088bc44fb308000026f1083db556ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:43:07 +0000 Subject: [PATCH 04/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/create_test_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/create_test_data.py b/tests/create_test_data.py index 21cd7e8..e487ed3 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -34,7 +34,6 @@ def main(): "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] - for file in files: src = inputdir / file dst = outdir / file From 10ef679829874e6000ab691a72061bd0dc10d45b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Thu, 23 Apr 2026 14:46:31 +0200 Subject: [PATCH 05/24] Minor - removed unused file in test --- tests/test_masks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_masks.py b/tests/test_masks.py index a6a26b1..c08a8dd 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -205,7 +205,6 @@ def test_csf_mask(tmp_path, mri_data_dir: Path): def test_intracranial_mask(tmp_path, mri_data_dir: Path): - csf_segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz" csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz" From 53ef36e4a3b9742712eabbff2b5e6f12b50429e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Thu, 23 Apr 2026 14:49:23 +0200 Subject: [PATCH 06/24] Fix mypy - convert MRIData to Segmentation type --- tests/test_segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 30592b5..792fa48 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -197,7 +197,8 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" test_output = tmp_path / "output_refined.nii.gz" - fs_input = Segmentation.from_file(FS_segmentation) + fs_input = Segmentation.from_file(FS_segmentation) #MRIData type + fs_input = Segmentation(data=fs_input.data, affine=fs_input.affine) # Convert to Segmentation type to access refinement methods result = fs_input.resample_to_reference(MRIData.from_file(ref_mri)) smoothed = result.smooth(sigma=smoothing) result.data = smoothed.data From d655ea038aad93329fb1261de87decc73dd0eac4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:49:45 +0000 Subject: [PATCH 07/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_segmentation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 792fa48..a9e2ae0 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -197,8 +197,10 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" test_output = tmp_path / "output_refined.nii.gz" - fs_input = Segmentation.from_file(FS_segmentation) #MRIData type - fs_input = Segmentation(data=fs_input.data, affine=fs_input.affine) # Convert to Segmentation type to access refinement methods + fs_input = Segmentation.from_file(FS_segmentation) # MRIData type + fs_input = Segmentation( + data=fs_input.data, affine=fs_input.affine + ) # Convert to Segmentation type to access refinement methods result = fs_input.resample_to_reference(MRIData.from_file(ref_mri)) smoothed = result.smooth(sigma=smoothing) result.data = smoothed.data From 0811a95da7842526d22fb9c92f747a434aed0f71 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Thu, 23 Apr 2026 16:42:00 +0200 Subject: [PATCH 08/24] Reset cache for data --- .github/workflows/setup-data.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 050ad3d..34548a8 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -22,7 +22,7 @@ jobs: path: data/ # The folder you want to cache # The key determines if we have a match. # Change 'v1' to 'v2' manually to force a re-download in the future. - key: test-data-v7 + key: test-data-v8 # 2. DOWNLOAD ONLY IF CACHE MISS From 3e093747c2d2fd790b58ac27224d2303ea4edd30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Fri, 24 Apr 2026 15:32:34 +0200 Subject: [PATCH 09/24] Update masks.py function input type and move csf_segmentation to segmentation.py (as suggested by Henrik) --- src/mritk/masks.py | 77 +++++++++++++-------------------------------- tests/test_masks.py | 43 ++++++++++++------------- 2 files changed, 41 insertions(+), 79 deletions(-) diff --git a/src/mritk/masks.py b/src/mritk/masks.py index 6343782..27c5187 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -13,6 +13,7 @@ import skimage from .data import MRIData +from .segmentation import CSFSegmentation from .testing import assert_same_space @@ -82,12 +83,12 @@ def compute_csf_mask_array( return binary -def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> MRIData: +def csf_mask(input: MRIData, connectivity: int | None = 2, use_li: bool = False) -> MRIData: """ I/O wrapper for generating and saving a CSF mask from a NIfTI file. Args: - input (Path): Path to the input NIfTI image. + input (MRIData): An MRIData object containing the input volume (typically T2-weighted or Spin-Echo). connectivity (Optional[int], optional): Connectivity distance. Defaults to 2. use_li (bool, optional): If True, uses Li thresholding. Defaults to False. output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. @@ -98,55 +99,14 @@ def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> Raises: AssertionError: If the resulting mask contains no voxels. """ - input_vol = MRIData.from_file(input, dtype=np.single) - mask = compute_csf_mask_array(input_vol.data, connectivity, use_li) - + mask = compute_csf_mask_array(input.data, connectivity, use_li) assert np.max(mask) > 0, "Masking failed, no voxels in mask" - mri_data = MRIData(data=mask, affine=input_vol.affine) + mri_data = MRIData(data=mask, affine=input.affine) return mri_data -def csf_segmentation(input_segmentation: Path | MRIData, csf_mask: Path | MRIData) -> MRIData: - """ - Generates a CSF segmentation by applying a CSF mask to an anatomical segmentation. - - This function takes an anatomical segmentation (e.g., from FreeSurfer) and a CSF mask, - and produces a new segmentation where voxels identified as CSF in the mask are labeled - with their original segmentation values, while non-CSF voxels are set to zero. - - Args: - input_segmentation (Path | MRIData): Path to the anatomical segmentation NIfTI file - or an MRIData object containing the resampled segmentation. - csf_mask (Path | MRIData): Either a path to a CSF mask NIfTI file or an MRIData object containing the mask. - - Returns: - MRIData: An MRIData object containing the CSF segmentation. - """ - if isinstance(input_segmentation, Path): - seg_mri = MRIData.from_file(input_segmentation, dtype=np.int16) - else: - seg_mri = input_segmentation - - if isinstance(csf_mask, Path): - csf_mask_mri = MRIData.from_file(csf_mask, dtype=bool) - else: - csf_mask_mri = csf_mask - - assert_same_space(seg_mri, csf_mask_mri) - - # Get interpolation operator - I, J, K = np.where(seg_mri.data != 0) - interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, seg_mri.data[I, J, K]) - # Interpolate segmentation values at CSF mask locations - i, j, k = np.where(csf_mask_mri.data != 0) - csf_seg = np.zeros_like(seg_mri.data, dtype=np.int16) - csf_seg[i, j, k] = interp(i, j, k) - - return MRIData(data=csf_seg.astype(np.int16), affine=csf_mask_mri.affine) - - def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_array: np.ndarray) -> np.ndarray: """ Combines a CSF mask array and a brain segmentation mask array into a solid intracranial mask. @@ -174,7 +134,7 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr return ~opened_background -def intracranial_mask(segmentation_path: Path, csf_mask_path: Path) -> MRIData: +def intracranial_mask(segmentation: MRIData, csf_mask: MRIData) -> MRIData: """ I/O wrapper for generating and saving an intracranial mask from NIfTI files. @@ -182,22 +142,20 @@ def intracranial_mask(segmentation_path: Path, csf_mask_path: Path) -> MRIData: delegates the array computation. Args: - segmentation_path (Path): Path to the brain (refined) segmentation NIfTI file, \ + segmentation (MRIData): The refined segmentation (MRIData), \ generated by the segmentation refinement module. - csf_mask_path (Path): Path to the CSF mask, generated by the csf mask module. + csf_mask (MRIData): The CSF mask (MRIData), generated by the csf mask module. Returns: MRIData: An MRIData object containing the intracranial mask. """ - # Get segmentation data and csf segmentation - segmentation_data = MRIData.from_file(segmentation_path, dtype=bool) - csf_seg = csf_segmentation(input_segmentation=segmentation_data, csf_mask=csf_mask_path) + csf_seg = CSFSegmentation(segmentation, csf_mask).to_csf_segmentation() # Validate spatial alignment before array operations - assert_same_space(csf_seg, segmentation_data) + assert_same_space(csf_seg, segmentation) - mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation_data.data) - mri_data = MRIData(data=mask_data, affine=segmentation_data.affine) + mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation.data) + mri_data = MRIData(data=mask_data, affine=segmentation.affine) return mri_data @@ -246,11 +204,18 @@ def add_arguments( def dispatch(args): command = args.pop("mask-command") if command == "csf": - csf_mask_data = csf_mask(input=args.pop("input"), connectivity=args.pop("connectivity"), use_li=args.pop("use_li")) + + csf_mask_data = csf_mask( + input=MRIData.from_file(args.pop("input")), + connectivity=args.pop("connectivity"), + use_li=args.pop("use_li") + ) csf_mask_data.save(args.pop("output"), dtype=np.uint8) elif command == "intracranial": + intracranial_mask_data = intracranial_mask( - segmentation_path=args.pop("segmentation_path"), csf_mask_path=args.pop("csf_mask_path") + segmentation=MRIData.from_file(args.pop("segmentation_path"), dtype=np.single), + csf_mask=MRIData.from_file(args.pop("csf_mask_path"), dtype=np.single) ) intracranial_mask_data.save(args.pop("output"), dtype=np.uint8) else: diff --git a/tests/test_masks.py b/tests/test_masks.py index c08a8dd..35c1d75 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -13,11 +13,11 @@ import pytest import mritk.cli +from mritk.data import MRIData from mritk.masks import ( compute_csf_mask_array, compute_intracranial_mask_array, csf_mask, - csf_segmentation, intracranial_mask, largest_island, ) @@ -126,7 +126,8 @@ def test_csf_mask_io(tmp_path): nii = nib.Nifti1Image(data, np.eye(4)) nib.save(nii, in_path) - result = csf_mask(input=in_path, use_li=True) + input_data = mritk.data.MRIData.from_file(in_path, dtype=np.single) + result = csf_mask(input=input_data, use_li=True) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -154,7 +155,7 @@ def test_intracranial_mask_io(tmp_path): seg_data[4:6, 4:6, 4:6] = 1.0 nib.save(nib.Nifti1Image(seg_data, affine), seg_path) - result = intracranial_mask(segmentation_path=seg_path, csf_mask_path=csf_path) + result = intracranial_mask(segmentation=MRIData(seg_data, affine), csf_mask=MRIData(csf_data, affine)) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -164,15 +165,18 @@ def test_intracranial_mask_io(tmp_path): @patch("mritk.masks.csf_mask") -def test_dispatch_csf_mask(mock_csf_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_csf_mask(mock_from_file, mock_csf_mask): """Test the CLI dispatch for the CSF mask command.""" - mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "--output", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) + mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "-o", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) - mock_csf_mask.assert_called_once_with(input=Path("input.nii.gz"), connectivity=2, use_li=True) + input_data = mock_from_file(Path("input.nii.gz"), dtype=np.single) + mock_csf_mask.assert_called_once_with(input=input_data, connectivity=2, use_li=True) @patch("mritk.masks.intracranial_mask") -def test_dispatch_intracranial_mask(mock_intracranial_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_intracranial_mask(mock_from_file, mock_intracranial_mask): """Test the CLI dispatch for the intracranial mask command.""" mritk.cli.main( [ @@ -187,8 +191,10 @@ def test_dispatch_intracranial_mask(mock_intracranial_mask): ] ) + seg_data = mock_from_file(Path("segmentation.nii.gz"), dtype=np.single) + csf_data = mock_from_file(Path("csf_mask.nii.gz"), dtype=np.single) mock_intracranial_mask.assert_called_once_with( - segmentation_path=Path("segmentation.nii.gz"), csf_mask_path=Path("csf_mask.nii.gz") + segmentation=seg_data, csf_mask=csf_data ) @@ -199,7 +205,8 @@ def test_csf_mask(tmp_path, mri_data_dir: Path): ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" test_output = tmp_path / "output_seg-csf_binary.nii.gz" - result = csf_mask(input=input_T2w_path, use_li=use_li) + input_T2w = mritk.data.MRIData.from_file(input_T2w_path, dtype=np.single) + result = csf_mask(input=input_T2w, use_li=use_li) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) @@ -211,20 +218,10 @@ def test_intracranial_mask(tmp_path, mri_data_dir: Path): ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz" test_output = tmp_path / "output_seg-intracranial_binary.nii.gz" - result = intracranial_mask(segmentation_path=segmentation_path, csf_mask_path=csf_mask_path) - result.save(test_output, dtype=np.uint8) - compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + input_segmentation = mritk.data.MRIData.from_file(segmentation_path, dtype=np.single) + input_csf_mask = mritk.data.MRIData.from_file(csf_mask_path, dtype=np.single) - -@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) -def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): - """Test the CSF segmentation logic by comparing against a known reference.""" - input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz" - input_csf_mask = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" - - ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" - test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz" - - result = csf_segmentation(input_segmentation=input_T2w_path, csf_mask=input_csf_mask) + result = intracranial_mask(segmentation=input_segmentation, csf_mask=input_csf_mask) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + From 9b81380e95bc8505a5b53f1350836ec53cff47f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Fri, 24 Apr 2026 15:36:17 +0200 Subject: [PATCH 10/24] Updates in segmentation: Segmentation class does not inherit from MRIData anymore and have its own from_file function, for the sake of consistency. csf_segmentation is moved to a dedicated CSFSegmentation class --- src/mritk/segmentation.py | 106 +++++++++++++++++++++++++++++-------- tests/conftest.py | 2 +- tests/test_segmentation.py | 80 +++++++++++++++++++++------- 3 files changed, 144 insertions(+), 44 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 7d5c457..1d13f66 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -12,6 +12,7 @@ from collections.abc import Callable from pathlib import Path from urllib.request import urlretrieve +from dataclasses import dataclass import numpy as np import numpy.typing as npt @@ -19,6 +20,7 @@ import scipy from .data import MRIData, apply_affine, load_mri_data +from .testing import assert_same_space logger = logging.getLogger(__name__) @@ -82,8 +84,7 @@ "subcortical-gm": SUBCORTICAL_GM_RANGES, } - -class Segmentation(MRIData): +class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. @@ -95,8 +96,9 @@ class Segmentation(MRIData): mri: MRIData rois: np.ndarray lut: pd.DataFrame + label_name: str - def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None): + def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): """ Initializes the Segmentation object. @@ -106,11 +108,10 @@ def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | Non lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. """ - super().__init__(data, affine) - self.data = self.data.astype(int) + self.mri = mri # Extract all unique active regions (ignoring 0/background) - self.rois = np.unique(self.data[self.data > 0]) + self.rois = np.unique(self.mri.data[self.mri.data > 0]) if lut is not None: self.lut = lut @@ -118,7 +119,39 @@ def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | Non self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) # Identify the primary label column dynamically - self._label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + + @classmethod + def from_file(cls, seg_path: Path) -> "Segmentation": + """Loads a Segmentation from a NIfTI file. + + Args: + seg_path (Path): The file path to the segmentation NIfTI file. + Returns: + Segmentation: An instance of the Segmentation class containing the loaded + segmentation data and affine transformation. + """ + logger.info(f"Loading segmentation from {seg_path}.") + mri = MRIData.from_file(seg_path, dtype=np.single) + + rois = np.unique(mri.data[mri.data > 0]) + lut = pd.DataFrame({"Label": rois}, index=rois) + + return cls(mri=mri, lut=lut) + + def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): + """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. + + Args: + lut (pd.DataFrame): A pandas DataFrame mapping numerical labels + to their descriptions. If None, a default numerical mapping is generated. Defaults to None. + label_column (str, optional): The name of the column in the LUT that contains the label descriptions. Defaults to "Label". + """ + + self.lut = lut + self.label_name = label_column + if self.label_name not in self.lut.columns: + raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.") @property def num_rois(self) -> int: @@ -153,7 +186,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr if not np.isin(rois, self.rois).all(): raise ValueError("Some of the provided ROIs are not present in the segmentation.") - return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index() + return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() def resample_to_reference(self, reference_mri: MRIData): """ @@ -166,7 +199,7 @@ def resample_to_reference(self, reference_mri: MRIData): Segmentation: A new Segmentation object containing the resampled data. """ - shape_in = self.shape + shape_in = self.mri.shape shape_out = reference_mri.shape # Generate a grid of voxel indices for the output space @@ -176,7 +209,7 @@ def resample_to_reference(self, reference_mri: MRIData): ) # Get voxel indices in the input segmentation space corresponding to the output grid seg_indices = apply_affine( - np.linalg.inv(self.affine), + np.linalg.inv(self.mri.affine), apply_affine(reference_mri.affine, upsampled_indices), ) seg_indices = np.rint(seg_indices).astype(int) @@ -187,12 +220,13 @@ def resample_to_reference(self, reference_mri: MRIData): upsampled_indices = upsampled_indices[valid_index_mask] seg_indices = seg_indices[valid_index_mask] - seg_upsampled = np.zeros(shape_out, dtype=self.data.dtype) + seg_upsampled = np.zeros(shape_out, dtype=self.mri.data.dtype) I_in, J_in, K_in = seg_indices.T I_out, J_out, K_out = upsampled_indices.T - seg_upsampled[I_out, J_out, K_out] = self.data[I_in, J_in, K_in] + seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] - return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + #return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + return MRIData(data=seg_upsampled, affine=reference_mri.affine) def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: """ @@ -207,19 +241,19 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: dict[str, np.ndarray]: A dictionary containing 'labels' (the smoothed segmentation) and 'scores' (the confidence scores for each voxel). """ - smoothed_rois = np.zeros_like(self.data) - high_scores = np.zeros(self.data.shape) + smoothed_rois = np.zeros_like(self.mri.data) + high_scores = np.zeros(self.mri.data.shape) for roi in self.rois: - scores = scipy.ndimage.gaussian_filter((self.data == roi).astype(float), sigma=sigma, **kwargs) + scores = scipy.ndimage.gaussian_filter((self.mri.data == roi).astype(float), sigma=sigma, **kwargs) is_new_high_score = scores > high_scores smoothed_rois[is_new_high_score] = roi high_scores[is_new_high_score] = scores[is_new_high_score] - delete_scores = (high_scores < cutoff_score) * (self.data == 0) + delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0) smoothed_rois[delete_scores] = 0 - return MRIData(data=smoothed_rois, affine=self.affine) + return MRIData(data=smoothed_rois, affine=self.mri.affine) class FreeSurferSegmentation(Segmentation): @@ -253,9 +287,8 @@ def from_file( # FreeSurfer LUTs index by the "label" column lut = lut.set_index("label") if "label" in lut.columns else lut - data, affine = load_mri_data(filepath, dtype=dtype, orient=orient) - return cls(data=data, affine=affine, lut=lut) - + mri = MRIData.from_file(filepath, dtype=dtype, orient=orient) + return cls(mri=mri, lut=lut) class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): """ @@ -292,7 +325,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr left_on="FreeSurfer_ROI", right_on="FreeSurfer_ROI", how="outer", - ).drop(columns=["FreeSurfer_ROI"])[["ROI", self._label_name, "tissue_type"]] + ).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]] def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame: """ @@ -322,6 +355,33 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF ret["FreeSurfer_ROI"] = ret["ROI"] % 10000 return ret +# @dataclass +class CSFSegmentation: + segmentation: MRIData + csf_mask: MRIData + + def __init__(self, segmentation: MRIData, csf_mask: MRIData): + assert_same_space(segmentation, csf_mask) + self.segmentation = segmentation + self.csf_mask = csf_mask + + @classmethod + def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation": + segmentation = MRIData.from_file(segmentation_path, dtype=np.int16) + csf_mask = MRIData.from_file(csf_mask_path, dtype=bool) + assert_same_space(segmentation, csf_mask) + return cls(segmentation=segmentation, csf_mask=csf_mask) + + def to_csf_segmentation(self) -> MRIData: + # Get interpolation operator + I, J, K = np.where(self.segmentation.data != 0) + interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.data[I, J, K]) + # Interpolate segmentation values at CSF mask locations + i, j, k = np.where(self.csf_mask.data != 0) + csf_seg = np.zeros_like(self.segmentation.data, dtype=np.int16) + csf_seg[i, j, k] = interp(i, j, k) + + return MRIData(data=csf_seg.astype(np.int16), affine=self.csf_mask.affine) def default_segmentation_groups() -> dict[str, list[int]]: """ @@ -541,7 +601,7 @@ def dispatch(args): if command == "resample": print("Resampling segmentation...") input_seg = Segmentation.from_file(args.pop("input")) - reference_mri = Segmentation.from_file(args.pop("reference")) + reference_mri = MRIData.from_file(args.pop("reference")) resampled_seg = input_seg.resample_to_reference(reference_mri) resampled_seg.save(args.pop("output"), dtype=np.int32) elif command == "smooth": diff --git a/tests/conftest.py b/tests/conftest.py index be56825..599b4c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def example_segmentation() -> Segmentation: base = np.array([0, 1, 2, 3], dtype=float) seg = np.tile(base, (100, 1)) - return Segmentation(seg, affine=np.eye(4)) + return Segmentation(MRIData(data=seg, affine=np.eye(4))) @pytest.fixture diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 792fa48..fc43cf9 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -12,6 +12,7 @@ VENTRICLES, ExtendedFreeSurferSegmentation, Segmentation, + CSFSegmentation, default_segmentation_groups, lut_record, read_freesurfer_lut, @@ -23,8 +24,8 @@ def test_segmentation_initialization(example_segmentation: Segmentation): - assert example_segmentation.data.shape == (100, 4) - assert example_segmentation.affine.shape == (4, 4) + assert example_segmentation.mri.data.shape == (100, 4) + assert example_segmentation.mri.affine.shape == (4, 4) assert example_segmentation.num_rois == 3 assert set(example_segmentation.roi_labels) == {1, 2, 3} assert example_segmentation.lut.shape == (3, 1) @@ -47,11 +48,11 @@ def test_freesurfer_segmentation_labels(mri_data_dir: Path): def test_extended_freesurfer_segmentation_labels(example_segmentation: Segmentation, mri_data_dir: Path): - data = example_segmentation.data + data = example_segmentation.mri.data data[0:2, 0:2] = 10001 # csf data[3:5, 3:5] = 20001 # dura - ext_fs_seg = ExtendedFreeSurferSegmentation(data, affine=np.eye(4)) + ext_fs_seg = ExtendedFreeSurferSegmentation(MRIData(data=data, affine=np.eye(4))) labels = ext_fs_seg.get_roi_labels() assert set(labels["ROI"]) == set(ext_fs_seg.roi_labels) @@ -186,35 +187,74 @@ def test_write_lut_file_io(tmp_path): # Note : Refinement is actually testing both resampling and smoothing -@pytest.mark.skip(reason="Takes too long to run") +# @pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.xfail( + reason=( + "Call to resample_to_reference fails " + "due to shape issue when using gonzo_roi. " + "Needs to be investigated further." + ) +) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) -def test_segmentation_refinement(tmp_path, mri_data_dir: Path, seg_type: str): - - ref_mri = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" - smoothing = 1 - - FS_segmentation = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" - ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" +def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): + + # Get gonzo_roi from FS_segmentation + FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" + fs_seg = Segmentation.from_file(FS_seg_path) #MRIData type + vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) + v = fs_seg.mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Get gonzo_roi from reference MRI to use as reference for resampling + ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" + ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_mri.affine) + v = ref_mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Output: Refine segmentation from gonzoi_roi segmentation and ref MRI test_output = tmp_path / "output_refined.nii.gz" - - fs_input = Segmentation.from_file(FS_segmentation) #MRIData type - fs_input = Segmentation(data=fs_input.data, affine=fs_input.affine) # Convert to Segmentation type to access refinement methods - result = fs_input.resample_to_reference(MRIData.from_file(ref_mri)) + smoothing = 1 + piece_fs_seg = Segmentation(mri=piece_fs_seg_data) + result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) smoothed = result.smooth(sigma=smoothing) result.data = smoothed.data result.save(test_output, dtype=np.int32) - compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + + mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): + """Test the CSF segmentation logic by comparing against a known reference.""" + input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz" + input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" + + ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz" + + input_T2w = MRIData.from_file(input_T2w_path, dtype=np.single) + input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) + result = CSFSegmentation(segmentation=input_T2w, csf_mask=input_csf_mask).to_csf_segmentation() + result.save(test_output, dtype=np.uint8) + compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) @patch("mritk.segmentation.Segmentation") -def test_dispatch_resample(mock_seg): +@patch("mritk.data.MRIData") +def test_dispatch_resample(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation resample.""" mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) - assert mock_seg.from_file.call_count == 2 # Called once for input segmentation and once for reference MRI + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"), dtype=np.single) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file - inst.resample_to_reference.assert_called_once_with(mock_seg.from_file.return_value) + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) @patch("mritk.segmentation.Segmentation") From b55bdf39a59aec424cb550e000f6dd51e28f8e34 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 13:38:29 +0000 Subject: [PATCH 11/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mritk/masks.py | 9 ++------- src/mritk/segmentation.py | 13 ++++++++----- tests/test_masks.py | 6 +----- tests/test_segmentation.py | 14 ++++++-------- 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/src/mritk/masks.py b/src/mritk/masks.py index 27c5187..f94e97f 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -9,7 +9,6 @@ from pathlib import Path import numpy as np -import scipy.interpolate import skimage from .data import MRIData @@ -204,18 +203,14 @@ def add_arguments( def dispatch(args): command = args.pop("mask-command") if command == "csf": - csf_mask_data = csf_mask( - input=MRIData.from_file(args.pop("input")), - connectivity=args.pop("connectivity"), - use_li=args.pop("use_li") + input=MRIData.from_file(args.pop("input")), connectivity=args.pop("connectivity"), use_li=args.pop("use_li") ) csf_mask_data.save(args.pop("output"), dtype=np.uint8) elif command == "intracranial": - intracranial_mask_data = intracranial_mask( segmentation=MRIData.from_file(args.pop("segmentation_path"), dtype=np.single), - csf_mask=MRIData.from_file(args.pop("csf_mask_path"), dtype=np.single) + csf_mask=MRIData.from_file(args.pop("csf_mask_path"), dtype=np.single), ) intracranial_mask_data.save(args.pop("output"), dtype=np.uint8) else: diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 1d13f66..34ce39c 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -12,14 +12,13 @@ from collections.abc import Callable from pathlib import Path from urllib.request import urlretrieve -from dataclasses import dataclass import numpy as np import numpy.typing as npt import pandas as pd import scipy -from .data import MRIData, apply_affine, load_mri_data +from .data import MRIData, apply_affine from .testing import assert_same_space logger = logging.getLogger(__name__) @@ -84,6 +83,7 @@ "subcortical-gm": SUBCORTICAL_GM_RANGES, } + class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. @@ -130,7 +130,7 @@ def from_file(cls, seg_path: Path) -> "Segmentation": Returns: Segmentation: An instance of the Segmentation class containing the loaded segmentation data and affine transformation. - """ + """ logger.info(f"Loading segmentation from {seg_path}.") mri = MRIData.from_file(seg_path, dtype=np.single) @@ -225,7 +225,7 @@ def resample_to_reference(self, reference_mri: MRIData): I_out, J_out, K_out = upsampled_indices.T seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] - #return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + # return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) return MRIData(data=seg_upsampled, affine=reference_mri.affine) def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: @@ -290,6 +290,7 @@ def from_file( mri = MRIData.from_file(filepath, dtype=dtype, orient=orient) return cls(mri=mri, lut=lut) + class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): """ Extended FreeSurfer segmentation handling custom tissue type classifications. @@ -355,6 +356,7 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF ret["FreeSurfer_ROI"] = ret["ROI"] % 10000 return ret + # @dataclass class CSFSegmentation: segmentation: MRIData @@ -371,7 +373,7 @@ def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmenta csf_mask = MRIData.from_file(csf_mask_path, dtype=bool) assert_same_space(segmentation, csf_mask) return cls(segmentation=segmentation, csf_mask=csf_mask) - + def to_csf_segmentation(self) -> MRIData: # Get interpolation operator I, J, K = np.where(self.segmentation.data != 0) @@ -383,6 +385,7 @@ def to_csf_segmentation(self) -> MRIData: return MRIData(data=csf_seg.astype(np.int16), affine=self.csf_mask.affine) + def default_segmentation_groups() -> dict[str, list[int]]: """ Returns the default grouping of FreeSurfer labels into brain regions. diff --git a/tests/test_masks.py b/tests/test_masks.py index 35c1d75..bdbca59 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -10,7 +10,6 @@ import nibabel as nib import numpy as np -import pytest import mritk.cli from mritk.data import MRIData @@ -193,9 +192,7 @@ def test_dispatch_intracranial_mask(mock_from_file, mock_intracranial_mask): seg_data = mock_from_file(Path("segmentation.nii.gz"), dtype=np.single) csf_data = mock_from_file(Path("csf_mask.nii.gz"), dtype=np.single) - mock_intracranial_mask.assert_called_once_with( - segmentation=seg_data, csf_mask=csf_data - ) + mock_intracranial_mask.assert_called_once_with(segmentation=seg_data, csf_mask=csf_data) def test_csf_mask(tmp_path, mri_data_dir: Path): @@ -224,4 +221,3 @@ def test_intracranial_mask(tmp_path, mri_data_dir: Path): result = intracranial_mask(segmentation=input_segmentation, csf_mask=input_csf_mask) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) - diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 9cb5c7e..ea77ab6 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -10,9 +10,9 @@ from mritk.segmentation import ( LUT_REGEX, VENTRICLES, + CSFSegmentation, ExtendedFreeSurferSegmentation, Segmentation, - CSFSegmentation, default_segmentation_groups, lut_record, read_freesurfer_lut, @@ -189,18 +189,14 @@ def test_write_lut_file_io(tmp_path): # Note : Refinement is actually testing both resampling and smoothing # @pytest.mark.skip(reason="Takes too long to run") @pytest.mark.xfail( - reason=( - "Call to resample_to_reference fails " - "due to shape issue when using gonzo_roi. " - "Needs to be investigated further." - ) + reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") ) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): # Get gonzo_roi from FS_segmentation FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" - fs_seg = Segmentation.from_file(FS_seg_path) #MRIData type + fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) v = fs_seg.mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) @@ -229,6 +225,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): """Test the CSF segmentation logic by comparing against a known reference.""" @@ -244,6 +241,7 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + @patch("mritk.segmentation.Segmentation") @patch("mritk.data.MRIData") def test_dispatch_resample(mock_seg, mock_mri_data): @@ -253,7 +251,7 @@ def test_dispatch_resample(mock_seg, mock_mri_data): mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"), dtype=np.single) mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) - + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) From 4351949c5a5d9e97d391850824cd6722dd8d9f5b Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:41:59 +0200 Subject: [PATCH 12/24] Various fixes --- .github/workflows/setup-data.yml | 2 +- src/mritk/segmentation.py | 52 +++++++++++++++++++++------ src/mritk/statistics/compute_stats.py | 6 ++-- tests/create_test_data.py | 1 + tests/test_mri_io.py | 5 +-- tests/test_segmentation.py | 40 +++++++++++++++++---- 6 files changed, 83 insertions(+), 23 deletions(-) diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 34548a8..2265a39 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -22,7 +22,7 @@ jobs: path: data/ # The folder you want to cache # The key determines if we have a match. # Change 'v1' to 'v2' manually to force a re-download in the future. - key: test-data-v8 + key: test-data-v9 # 2. DOWNLOAD ONLY IF CACHE MISS diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 34ce39c..56cc514 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -122,30 +122,57 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] @classmethod - def from_file(cls, seg_path: Path) -> "Segmentation": + def from_file( + cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + ) -> "Segmentation": """Loads a Segmentation from a NIfTI file. Args: seg_path (Path): The file path to the segmentation NIfTI file. + dtype (npt.DTypeLike, optional): The data type for the segmentation data. Defaults to None. + orient (bool, optional): Whether to orient the data. Defaults to True. + lut_path (Path, optional): The file path to the lookup table. Defaults to None. Returns: Segmentation: An instance of the Segmentation class containing the loaded segmentation data and affine transformation. """ logger.info(f"Loading segmentation from {seg_path}.") - mri = MRIData.from_file(seg_path, dtype=np.single) + mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient) + + if lut_path is None and seg_path.with_suffix(".json").exists(): + lut_path = seg_path.with_suffix(".json") - rois = np.unique(mri.data[mri.data > 0]) - lut = pd.DataFrame({"Label": rois}, index=rois) + if lut_path is not None: + logger.info(f"Loading LUT from {lut_path}.") + lut = pd.read_json(lut_path) + else: + rois = np.unique(mri.data[mri.data > 0]) + lut = pd.DataFrame({"Label": rois}, index=rois) return cls(mri=mri, lut=lut) + def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None): + """Saves the Segmentation to a NIfTI file. + + Args: + output_path (Path): The file path where the segmentation will be saved. + dtype (npt.DTypeLike, optional): The data type for the saved segmentation data. Defaults to None. + intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL). + """ + self.mri.save(output_path, dtype=dtype, intent_code=intent_code) + if lut_path is not None: + self.lut.to_json(lut_path, orient="index") + else: + self.lut.to_json(output_path.with_suffix(".json"), orient="index") + def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. Args: lut (pd.DataFrame): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. - label_column (str, optional): The name of the column in the LUT that contains the label descriptions. Defaults to "Label". + label_column (str, optional): The name of the column in the LUT that contains the label + descriptions. Defaults to "Label". """ self.lut = lut @@ -188,7 +215,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() - def resample_to_reference(self, reference_mri: MRIData): + def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation": """ Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI. @@ -226,9 +253,10 @@ def resample_to_reference(self, reference_mri: MRIData): seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] # return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) - return MRIData(data=seg_upsampled, affine=reference_mri.affine) + mri = MRIData(data=seg_upsampled, affine=reference_mri.affine) + return Segmentation(mri=mri, lut=self.lut) - def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: + def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> "Segmentation": """ Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map. @@ -253,7 +281,8 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0) smoothed_rois[delete_scores] = 0 - return MRIData(data=smoothed_rois, affine=self.mri.affine) + mri = MRIData(data=smoothed_rois, affine=self.mri.affine) + return Segmentation(mri=mri, lut=self.lut) class FreeSurferSegmentation(Segmentation): @@ -607,14 +636,17 @@ def dispatch(args): reference_mri = MRIData.from_file(args.pop("reference")) resampled_seg = input_seg.resample_to_reference(reference_mri) resampled_seg.save(args.pop("output"), dtype=np.int32) + elif command == "smooth": smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff")) smoothed.save(args.pop("output"), dtype=np.int32) + elif command == "refine": seg = Segmentation.from_file(args.pop("input")) refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) smoothed = refined.smooth(sigma=args.pop("smooth")) - refined.data = np.where(smoothed.data > 0, smoothed.data, refined.data) + refined.mri.data = np.where(smoothed.data > 0, smoothed.data, refined.data) refined.save(args.pop("output"), dtype=np.int32) + else: raise ValueError(f"Unknown segmentation command: {command}") diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index 9fb4ca5..b5ccc10 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -219,7 +219,7 @@ def generate_stats_dataframe_rois( metadata: Optional[dict] = None, ) -> pd.DataFrame: # Verify that segmentation and MRI are in the same space - assert_same_space(seg, mri) + assert_same_space(seg.mri, mri) qoi_records = [] # Collects records related to qois roi_records = [] # Collects records related to ROIs, @@ -228,7 +228,7 @@ def generate_stats_dataframe_rois( finite_mask = np.isfinite(mri.data) for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)): # Identify rois in segmentation - region_mask = (seg.data == roi) * finite_mask + region_mask = (seg.mri.data == roi) * finite_mask # print(region_mask.shape) region_data = mri.data[region_mask] nb_nans = np.isnan(region_data).sum() @@ -239,7 +239,7 @@ def generate_stats_dataframe_rois( { "ROI": roi, "voxel_count": voxelcount, - "volume_ml": seg.voxel_ml_volume * voxelcount, + "volume_ml": seg.mri.voxel_ml_volume * voxelcount, "num_nan_values": nb_nans, } ) diff --git a/tests/create_test_data.py b/tests/create_test_data.py index e487ed3..68571b0 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -29,6 +29,7 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", + "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", diff --git a/tests/test_mri_io.py b/tests/test_mri_io.py index 57aa442..c98f9a3 100644 --- a/tests/test_mri_io.py +++ b/tests/test_mri_io.py @@ -55,8 +55,9 @@ def test_load_mri_data_invalid_suffix(mri_data_dir): @pytest.mark.parametrize("orient", (True, False)) def test_load_Segmentation(tmp_path, mri_data_dir, orient: bool): input_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - seg = Segmentation.from_file(input_file) - assert seg.data.dtype == int + seg = Segmentation.from_file(input_file, dtype=np.int32) + + assert seg.mri.data.dtype == np.int32 mri = MRIData.from_file(input_file, dtype=np.single, orient=orient) output_file = tmp_path.with_suffix(".nii.gz") mri.save(output_file, dtype=np.single) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index ea77ab6..7d13493 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -187,13 +187,11 @@ def test_write_lut_file_io(tmp_path): # Note : Refinement is actually testing both resampling and smoothing -# @pytest.mark.skip(reason="Takes too long to run") -@pytest.mark.xfail( - reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") -) +# @pytest.mark.xfail( +# reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") +# ) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): - # Get gonzo_roi from FS_segmentation FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type @@ -215,7 +213,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty piece_fs_seg = Segmentation(mri=piece_fs_seg_data) result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) smoothed = result.smooth(sigma=smoothing) - result.data = smoothed.data + result.mri.data = smoothed.mri.data result.save(test_output, dtype=np.int32) ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" @@ -223,7 +221,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty vi = gonzo_roi.voxel_indices(affine=ref_output.affine) v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) - mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) @@ -265,3 +263,31 @@ def test_dispatch_smoothing(mock_seg): mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + + +@patch("mritk.segmentation.Segmentation") +@patch("mritk.data.MRIData") +def test_dispatch_refine(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation refinement.""" + + mritk.cli.main( + [ + "seg", + "refine", + "-i", + "mock_in.nii.gz", + "-r", + "mock_ref.nii.gz", + "-o", + "mock_out.nii.gz", + "-s", + "1", + ] + ) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) From 78ef8e10bf1170569a272bdc9679a51a18d8b984 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:48:57 +0200 Subject: [PATCH 13/24] Fix shape issue in refinement test --- tests/test_segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 7d13493..27561b2 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -196,14 +196,14 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) - v = fs_seg.mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v = fs_seg.mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) # Get gonzo_roi from reference MRI to use as reference for resampling ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=ref_mri.affine) - v = ref_mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v = ref_mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) # Output: Refine segmentation from gonzoi_roi segmentation and ref MRI @@ -219,7 +219,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=ref_output.affine) - v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) From 97e4acfa14701eebb0d59cee6810ca87f3fc0277 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:53:36 +0200 Subject: [PATCH 14/24] Use gonzo roi for csf segmentation test --- tests/test_segmentation.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 27561b2..b930d60 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -20,7 +20,6 @@ validate_lut_file, write_lut, ) -from mritk.testing import compare_nifti_images def test_segmentation_initialization(example_segmentation: Segmentation): @@ -225,19 +224,30 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) -def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): +def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): """Test the CSF segmentation logic by comparing against a known reference.""" - input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz" + input_seg_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" - ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" - test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz" + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + + input_seg = MRIData.from_file(input_seg_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_seg.affine) + v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) - input_T2w = MRIData.from_file(input_T2w_path, dtype=np.single) input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) - result = CSFSegmentation(segmentation=input_T2w, csf_mask=input_csf_mask).to_csf_segmentation() - result.save(test_output, dtype=np.uint8) - compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine) + v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + result = CSFSegmentation(segmentation=piece_seg_data, csf_mask=piece_csf_mask_data).to_csf_segmentation() + + ref_output = MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) @patch("mritk.segmentation.Segmentation") From e5e28fed7f52b6c332d1bc136fc1f6a256259dab Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:58:16 +0200 Subject: [PATCH 15/24] Fix dispatch tests in segmentation --- src/mritk/segmentation.py | 2 +- tests/test_segmentation.py | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 56cc514..a55e610 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -645,7 +645,7 @@ def dispatch(args): seg = Segmentation.from_file(args.pop("input")) refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) smoothed = refined.smooth(sigma=args.pop("smooth")) - refined.mri.data = np.where(smoothed.data > 0, smoothed.data, refined.data) + refined.mri.data = np.where(smoothed.mri.data > 0, smoothed.mri.data, refined.mri.data) refined.save(args.pop("output"), dtype=np.int32) else: diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index b930d60..aa12d83 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -250,15 +250,15 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) +@patch("mritk.segmentation.MRIData") @patch("mritk.segmentation.Segmentation") -@patch("mritk.data.MRIData") def test_dispatch_resample(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation resample.""" mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) - mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"), dtype=np.single) - mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) @@ -272,14 +272,25 @@ def test_dispatch_smoothing(mock_seg): mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file - inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + inst.smooth.assert_called_once_with(sigma=1.0, cutoff_score=0.5) +@patch("mritk.segmentation.MRIData") @patch("mritk.segmentation.Segmentation") -@patch("mritk.data.MRIData") def test_dispatch_refine(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation refinement.""" + # Mock the underlying data arrays to avoid TypeError in np.where + inst = mock_seg.from_file.return_value + refined_inst = inst.resample_to_reference.return_value + smoothed_inst = refined_inst.smooth.return_value + + # Setup mock numpy arrays for the attributes used in np.where + smoothed_inst.data = np.array([1]) # In case the source code bug isn't fixed yet + refined_inst.data = np.array([0]) # In case the source code bug isn't fixed yet + refined_inst.mri.data = np.array([0]) # Correct fixed access + smoothed_inst.mri.data = np.array([1]) # Correct fixed access + mritk.cli.main( [ "seg", @@ -296,8 +307,8 @@ def test_dispatch_refine(mock_seg, mock_mri_data): ) mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) - mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) - inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) - inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + refined_inst.smooth.assert_called_once_with(sigma=1.0) + refined_inst.save.assert_called_once_with(Path("mock_out.nii.gz"), dtype=np.int32) From 7edd89982c452508f2add12daf405d242d7ed844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Sat, 25 Apr 2026 13:40:00 +0200 Subject: [PATCH 16/24] fixes in segmentation classes --- src/mritk/segmentation.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index a55e610..a1270e9 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -12,6 +12,7 @@ from collections.abc import Callable from pathlib import Path from urllib.request import urlretrieve +from dataclasses import dataclass import numpy as np import numpy.typing as npt @@ -83,7 +84,7 @@ "subcortical-gm": SUBCORTICAL_GM_RANGES, } - +@dataclass class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. @@ -93,11 +94,6 @@ class Segmentation: labels to a descriptive Lookup Table (LUT). """ - mri: MRIData - rois: np.ndarray - lut: pd.DataFrame - label_name: str - def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): """ Initializes the Segmentation object. @@ -386,9 +382,9 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF return ret -# @dataclass +@dataclass class CSFSegmentation: - segmentation: MRIData + segmentation: Segmentation csf_mask: MRIData def __init__(self, segmentation: MRIData, csf_mask: MRIData): From fe21f1a48bd41cf689094494f637e479ef1c558b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Apr 2026 11:40:24 +0000 Subject: [PATCH 17/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mritk/segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index a1270e9..a9ff9bf 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -10,9 +10,9 @@ import os import re from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path from urllib.request import urlretrieve -from dataclasses import dataclass import numpy as np import numpy.typing as npt @@ -84,6 +84,7 @@ "subcortical-gm": SUBCORTICAL_GM_RANGES, } + @dataclass class Segmentation: """ From d4e05e91a89ea5a53531fbcf1cd48d6270a9712e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Sat, 25 Apr 2026 13:42:30 +0200 Subject: [PATCH 18/24] fixes in segmentation classes --- src/mritk/segmentation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index a1270e9..615e9c3 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -387,7 +387,7 @@ class CSFSegmentation: segmentation: Segmentation csf_mask: MRIData - def __init__(self, segmentation: MRIData, csf_mask: MRIData): + def __init__(self, segmentation: Segmentation, csf_mask: MRIData): assert_same_space(segmentation, csf_mask) self.segmentation = segmentation self.csf_mask = csf_mask @@ -401,11 +401,11 @@ def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmenta def to_csf_segmentation(self) -> MRIData: # Get interpolation operator - I, J, K = np.where(self.segmentation.data != 0) - interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.data[I, J, K]) + I, J, K = np.where(self.segmentation.mri.data != 0) + interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.mri.data[I, J, K]) # Interpolate segmentation values at CSF mask locations i, j, k = np.where(self.csf_mask.data != 0) - csf_seg = np.zeros_like(self.segmentation.data, dtype=np.int16) + csf_seg = np.zeros_like(self.segmentation.mri.data, dtype=np.int16) csf_seg[i, j, k] = interp(i, j, k) return MRIData(data=csf_seg.astype(np.int16), affine=self.csf_mask.affine) From edca7a17e021da63b3fa04f825d9bf0a31db2fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Sat, 25 Apr 2026 13:44:38 +0200 Subject: [PATCH 19/24] fixes in segmentation classes --- src/mritk/segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index d2705f8..db00e95 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -389,15 +389,15 @@ class CSFSegmentation: csf_mask: MRIData def __init__(self, segmentation: Segmentation, csf_mask: MRIData): - assert_same_space(segmentation, csf_mask) + assert_same_space(segmentation.mri, csf_mask) self.segmentation = segmentation self.csf_mask = csf_mask @classmethod def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation": - segmentation = MRIData.from_file(segmentation_path, dtype=np.int16) + segmentation = Segmentation.from_file(segmentation_path, dtype=np.int16) csf_mask = MRIData.from_file(csf_mask_path, dtype=bool) - assert_same_space(segmentation, csf_mask) + assert_same_space(segmentation.mri, csf_mask) return cls(segmentation=segmentation, csf_mask=csf_mask) def to_csf_segmentation(self) -> MRIData: From 5b0e54bb36a3cf85522fef231cc60f9304eee327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Sat, 25 Apr 2026 13:49:50 +0200 Subject: [PATCH 20/24] Fixes after changes in CSFSegmentation --- src/mritk/masks.py | 10 +++++----- tests/test_segmentation.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/mritk/masks.py b/src/mritk/masks.py index f94e97f..394ea9e 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -12,7 +12,7 @@ import skimage from .data import MRIData -from .segmentation import CSFSegmentation +from .segmentation import Segmentation, CSFSegmentation from .testing import assert_same_space @@ -133,7 +133,7 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr return ~opened_background -def intracranial_mask(segmentation: MRIData, csf_mask: MRIData) -> MRIData: +def intracranial_mask(segmentation: Segmentation, csf_mask: MRIData) -> MRIData: """ I/O wrapper for generating and saving an intracranial mask from NIfTI files. @@ -151,10 +151,10 @@ def intracranial_mask(segmentation: MRIData, csf_mask: MRIData) -> MRIData: csf_seg = CSFSegmentation(segmentation, csf_mask).to_csf_segmentation() # Validate spatial alignment before array operations - assert_same_space(csf_seg, segmentation) + assert_same_space(csf_seg, segmentation.mri) - mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation.data) - mri_data = MRIData(data=mask_data, affine=segmentation.affine) + mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation.mri.data) + mri_data = MRIData(data=mask_data, affine=segmentation.mri.affine) return mri_data diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index aa12d83..899cfd4 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -235,13 +235,14 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): vi = gonzo_roi.voxel_indices(affine=input_seg.affine) v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + piece_seg = Segmentation(mri=piece_seg_data) input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine) v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) - result = CSFSegmentation(segmentation=piece_seg_data, csf_mask=piece_csf_mask_data).to_csf_segmentation() + result = CSFSegmentation(segmentation=piece_seg, csf_mask=piece_csf_mask_data).to_csf_segmentation() ref_output = MRIData.from_file(ref_output_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=ref_output.affine) From 7b5660697aa281da629014f82587031cdc138a62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Apr 2026 11:50:23 +0000 Subject: [PATCH 21/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mritk/masks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mritk/masks.py b/src/mritk/masks.py index 394ea9e..18799e5 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -12,7 +12,7 @@ import skimage from .data import MRIData -from .segmentation import Segmentation, CSFSegmentation +from .segmentation import CSFSegmentation, Segmentation from .testing import assert_same_space From f833bc3418beb4d95041cee686ccd9b35e4cc763 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9cile?= Date: Sat, 25 Apr 2026 13:55:09 +0200 Subject: [PATCH 22/24] Fix tests --- tests/test_masks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_masks.py b/tests/test_masks.py index bdbca59..d2bfa01 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -20,6 +20,7 @@ intracranial_mask, largest_island, ) +from mritk.segmentation import Segmentation from mritk.testing import compare_nifti_images @@ -154,7 +155,7 @@ def test_intracranial_mask_io(tmp_path): seg_data[4:6, 4:6, 4:6] = 1.0 nib.save(nib.Nifti1Image(seg_data, affine), seg_path) - result = intracranial_mask(segmentation=MRIData(seg_data, affine), csf_mask=MRIData(csf_data, affine)) + result = intracranial_mask(segmentation=Segmentation(mri=MRIData(seg_data, affine)), csf_mask=MRIData(csf_data, affine)) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -215,7 +216,7 @@ def test_intracranial_mask(tmp_path, mri_data_dir: Path): ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz" test_output = tmp_path / "output_seg-intracranial_binary.nii.gz" - input_segmentation = mritk.data.MRIData.from_file(segmentation_path, dtype=np.single) + input_segmentation = mritk.segmentation.Segmentation.from_file(segmentation_path, dtype=np.single) input_csf_mask = mritk.data.MRIData.from_file(csf_mask_path, dtype=np.single) result = intracranial_mask(segmentation=input_segmentation, csf_mask=input_csf_mask) From 7d61f357a7c013f18cb94197d962c7ea6d0bc2c3 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Sun, 26 Apr 2026 19:34:02 +0200 Subject: [PATCH 23/24] Fix Segmentation dataclass --- src/mritk/segmentation.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index db00e95..2da0a5b 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -85,7 +85,7 @@ } -@dataclass +@dataclass(init=False) class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. @@ -93,27 +93,28 @@ class Segmentation: This class extends MRIData by specifically treating the image array as discrete integer labels representing Regions of Interest (ROIs). It links these numerical labels to a descriptive Lookup Table (LUT). + + Args: + data (np.ndarray): 3D numpy array containing integer ROI labels. + affine (np.ndarray): 4x4 affine transformation matrix mapping voxel indices to physical space. + lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels + to their descriptions. If None, a default numerical mapping is generated. Defaults to None. """ - def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): - """ - Initializes the Segmentation object. + mri: MRIData + lut: pd.DataFrame + label_name: str + rois: np.ndarray - Args: - data (np.ndarray): 3D numpy array containing integer ROI labels. - affine (np.ndarray): 4x4 affine transformation matrix mapping voxel indices to physical space. - lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels - to their descriptions. If None, a default numerical mapping is generated. Defaults to None. - """ + def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): self.mri = mri - # Extract all unique active regions (ignoring 0/background) self.rois = np.unique(self.mri.data[self.mri.data > 0]) - if lut is not None: - self.lut = lut - else: + if lut is None: self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) + else: + self.lut = lut # Identify the primary label column dynamically self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] @@ -388,10 +389,8 @@ class CSFSegmentation: segmentation: Segmentation csf_mask: MRIData - def __init__(self, segmentation: Segmentation, csf_mask: MRIData): - assert_same_space(segmentation.mri, csf_mask) - self.segmentation = segmentation - self.csf_mask = csf_mask + def __post_init__(self): + assert_same_space(self.segmentation.mri, self.csf_mask) @classmethod def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation": From 3294721f750c4ae5daab926818b05ddc87f2658f Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Sun, 26 Apr 2026 19:36:50 +0200 Subject: [PATCH 24/24] Update docstrings --- src/mritk/segmentation.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 2da0a5b..ef8ffc5 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -95,8 +95,7 @@ class Segmentation: labels to a descriptive Lookup Table (LUT). Args: - data (np.ndarray): 3D numpy array containing integer ROI labels. - affine (np.ndarray): 4x4 affine transformation matrix mapping voxel indices to physical space. + mri (MRIData): The MRIData object containing the segmentation volume and affine. lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. """ @@ -386,6 +385,19 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF @dataclass class CSFSegmentation: + """ + A specialized segmentation class for isolating Cerebrospinal Fluid (CSF) regions. + + This class combines a standard anatomical segmentation (e.g., FreeSurfer) with a + binary mask specifically targeting CSF regions. It provides functionality to + generate a new segmentation volume where only the CSF-labeled voxels are retained, + while all other voxels are set to zero. + + Args: + segmentation (Segmentation): The anatomical segmentation containing the full set of labels. + csf_mask (MRIData): A binary mask isolating the CSF regions, aligned in the same space as the segmentation. + """ + segmentation: Segmentation csf_mask: MRIData @@ -400,6 +412,8 @@ def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmenta return cls(segmentation=segmentation, csf_mask=csf_mask) def to_csf_segmentation(self) -> MRIData: + """Generates a new MRIData object containing only the CSF-labeled + voxels from the original segmentation.""" # Get interpolation operator I, J, K = np.where(self.segmentation.mri.data != 0) interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.mri.data[I, J, K])