diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 050ad3d..2265a39 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -22,7 +22,7 @@ jobs: path: data/ # The folder you want to cache # The key determines if we have a match. # Change 'v1' to 'v2' manually to force a re-download in the future. - key: test-data-v7 + key: test-data-v9 # 2. DOWNLOAD ONLY IF CACHE MISS diff --git a/src/mritk/cli.py b/src/mritk/cli.py index 10fe43a..f657c50 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -9,7 +9,7 @@ from rich.logging import RichHandler from rich_argparse import RichHelpFormatter -from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, show, statistics +from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, segmentation, show, statistics def version_info(): @@ -75,6 +75,9 @@ def setup_parser(): napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class) napari.add_arguments(napari_parser) + segmentation_parser = subparsers.add_parser("seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class) + segmentation.add_arguments(segmentation_parser, extra_args_cb=add_extra_arguments) + looklocker_parser = subparsers.add_parser( "looklocker", help="Process Look-Locker data", formatter_class=parser.formatter_class ) @@ -142,6 +145,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No show.dispatch(args) elif command == "napari": napari.dispatch(args) + elif command == "seg": + segmentation.dispatch(args) elif command == "looklocker": looklocker.dispatch(args) elif command == "mask": diff --git a/src/mritk/masks.py b/src/mritk/masks.py index a50b4c6..18799e5 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -12,6 +12,7 @@ import skimage from .data import MRIData +from .segmentation import CSFSegmentation, Segmentation from .testing import assert_same_space @@ -81,12 +82,12 @@ def compute_csf_mask_array( return binary -def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> MRIData: +def csf_mask(input: MRIData, connectivity: int | None = 2, use_li: bool = False) -> MRIData: """ I/O wrapper for generating and saving a CSF mask from a NIfTI file. Args: - input (Path): Path to the input NIfTI image. + input (MRIData): An MRIData object containing the input volume (typically T2-weighted or Spin-Echo). connectivity (Optional[int], optional): Connectivity distance. Defaults to 2. use_li (bool, optional): If True, uses Li thresholding. Defaults to False. output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. @@ -97,12 +98,10 @@ def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> Raises: AssertionError: If the resulting mask contains no voxels. """ - input_vol = MRIData.from_file(input, dtype=np.single) - mask = compute_csf_mask_array(input_vol.data, connectivity, use_li) - + mask = compute_csf_mask_array(input.data, connectivity, use_li) assert np.max(mask) > 0, "Masking failed, no voxels in mask" - mri_data = MRIData(data=mask, affine=input_vol.affine) + mri_data = MRIData(data=mask, affine=input.affine) return mri_data @@ -134,7 +133,7 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr return ~opened_background -def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> MRIData: +def intracranial_mask(segmentation: Segmentation, csf_mask: MRIData) -> MRIData: """ I/O wrapper for generating and saving an intracranial mask from NIfTI files. @@ -142,21 +141,20 @@ def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> M delegates the array computation. Args: - csf_segmentation_path (Path): Path to the CSF segmentation NIfTI file. - segmentation_path (Path): Path to the brain segmentation NIfTI file. - output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. + segmentation (MRIData): The refined segmentation (MRIData), \ + generated by the segmentation refinement module. + csf_mask (MRIData): The CSF mask (MRIData), generated by the csf mask module. Returns: MRIData: An MRIData object containing the intracranial mask. """ - input_csf_mask = MRIData.from_file(csf_segmentation_path, dtype=bool) - segmentation_data = MRIData.from_file(segmentation_path, dtype=bool) + csf_seg = CSFSegmentation(segmentation, csf_mask).to_csf_segmentation() # Validate spatial alignment before array operations - assert_same_space(input_csf_mask, segmentation_data) + assert_same_space(csf_seg, segmentation.mri) - mask_data = compute_intracranial_mask_array(input_csf_mask.data, segmentation_data.data) - mri_data = MRIData(data=mask_data, affine=segmentation_data.affine) + mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation.mri.data) + mri_data = MRIData(data=mask_data, affine=segmentation.mri.affine) return mri_data @@ -183,8 +181,18 @@ def add_arguments( intracranial_mask_parser = subparser.add_parser( "intracranial", help="Compute intracranial mask", formatter_class=parser.formatter_class ) - intracranial_mask_parser.add_argument("--csf-segmentation-path", type=Path, help="Path to the CSF segmentation NIfTI file") - intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to the brain segmentation NIfTI file") + intracranial_mask_parser.add_argument( + "--segmentation-path", + type=Path, + help="Path to refined segmentation file, generated by \ + the segmentation refinement module, i.e. mritk seg refine", + ) + intracranial_mask_parser.add_argument( + "--csf-mask-path", + type=Path, + help="Path to the CSF mask NIfTI file, generated by \ + the csf mask module, i.e. mritk mask csf", + ) intracranial_mask_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resulting mask") if extra_args_cb is not None: @@ -195,11 +203,14 @@ def add_arguments( def dispatch(args): command = args.pop("mask-command") if command == "csf": - csf_mask_data = csf_mask(input=args.pop("input"), connectivity=args.pop("connectivity"), use_li=args.pop("use_li")) + csf_mask_data = csf_mask( + input=MRIData.from_file(args.pop("input")), connectivity=args.pop("connectivity"), use_li=args.pop("use_li") + ) csf_mask_data.save(args.pop("output"), dtype=np.uint8) elif command == "intracranial": intracranial_mask_data = intracranial_mask( - csf_segmentation_path=args.pop("csf_segmentation_path"), segmentation_path=args.pop("segmentation_path") + segmentation=MRIData.from_file(args.pop("segmentation_path"), dtype=np.single), + csf_mask=MRIData.from_file(args.pop("csf_mask_path"), dtype=np.single), ) intracranial_mask_data.save(args.pop("output"), dtype=np.uint8) else: diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 0f97fe5..ef8ffc5 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -4,17 +4,23 @@ # Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no) # Copyright (C) 2026 Simula Research Laboratory +import argparse +import itertools import logging import os import re +from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path from urllib.request import urlretrieve import numpy as np import numpy.typing as npt import pandas as pd +import scipy -from .data import MRIData, load_mri_data +from .data import MRIData, apply_affine +from .testing import assert_same_space logger = logging.getLogger(__name__) @@ -79,38 +85,97 @@ } -class Segmentation(MRIData): +@dataclass(init=False) +class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. This class extends MRIData by specifically treating the image array as discrete integer labels representing Regions of Interest (ROIs). It links these numerical labels to a descriptive Lookup Table (LUT). + + Args: + mri (MRIData): The MRIData object containing the segmentation volume and affine. + lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels + to their descriptions. If None, a default numerical mapping is generated. Defaults to None. """ - def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None): - """ - Initializes the Segmentation object. + mri: MRIData + lut: pd.DataFrame + label_name: str + rois: np.ndarray + + def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): + self.mri = mri + # Extract all unique active regions (ignoring 0/background) + self.rois = np.unique(self.mri.data[self.mri.data > 0]) + + if lut is None: + self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) + else: + self.lut = lut + + # Identify the primary label column dynamically + self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + + @classmethod + def from_file( + cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + ) -> "Segmentation": + """Loads a Segmentation from a NIfTI file. Args: - data (np.ndarray): 3D numpy array containing integer ROI labels. - affine (np.ndarray): 4x4 affine transformation matrix mapping voxel indices to physical space. - lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels - to their descriptions. If None, a default numerical mapping is generated. Defaults to None. + seg_path (Path): The file path to the segmentation NIfTI file. + dtype (npt.DTypeLike, optional): The data type for the segmentation data. Defaults to None. + orient (bool, optional): Whether to orient the data. Defaults to True. + lut_path (Path, optional): The file path to the lookup table. Defaults to None. + Returns: + Segmentation: An instance of the Segmentation class containing the loaded + segmentation data and affine transformation. """ - super().__init__(data, affine) - self.data = self.data.astype(int) + logger.info(f"Loading segmentation from {seg_path}.") + mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient) - # Extract all unique active regions (ignoring 0/background) - self.rois = np.unique(self.data[self.data > 0]) + if lut_path is None and seg_path.with_suffix(".json").exists(): + lut_path = seg_path.with_suffix(".json") - if lut is not None: - self.lut = lut + if lut_path is not None: + logger.info(f"Loading LUT from {lut_path}.") + lut = pd.read_json(lut_path) else: - self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) + rois = np.unique(mri.data[mri.data > 0]) + lut = pd.DataFrame({"Label": rois}, index=rois) - # Identify the primary label column dynamically - self._label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + return cls(mri=mri, lut=lut) + + def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None): + """Saves the Segmentation to a NIfTI file. + + Args: + output_path (Path): The file path where the segmentation will be saved. + dtype (npt.DTypeLike, optional): The data type for the saved segmentation data. Defaults to None. + intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL). + """ + self.mri.save(output_path, dtype=dtype, intent_code=intent_code) + if lut_path is not None: + self.lut.to_json(lut_path, orient="index") + else: + self.lut.to_json(output_path.with_suffix(".json"), orient="index") + + def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): + """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. + + Args: + lut (pd.DataFrame): A pandas DataFrame mapping numerical labels + to their descriptions. If None, a default numerical mapping is generated. Defaults to None. + label_column (str, optional): The name of the column in the LUT that contains the label + descriptions. Defaults to "Label". + """ + + self.lut = lut + self.label_name = label_column + if self.label_name not in self.lut.columns: + raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.") @property def num_rois(self) -> int: @@ -145,7 +210,76 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr if not np.isin(rois, self.rois).all(): raise ValueError("Some of the provided ROIs are not present in the segmentation.") - return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index() + return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() + + def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation": + """ + Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI. + + Args: + reference_mri (MRIData): The MRI to which the segmentation will be resampled, + for example a T1-weighted anatomical scan. + Returns: + Segmentation: A new Segmentation object containing the resampled data. + """ + + shape_in = self.mri.shape + shape_out = reference_mri.shape + + # Generate a grid of voxel indices for the output space + upsampled_indices = np.fromiter( + itertools.product(*(np.arange(ni) for ni in shape_out)), + dtype=np.dtype((int, 3)), + ) + # Get voxel indices in the input segmentation space corresponding to the output grid + seg_indices = apply_affine( + np.linalg.inv(self.mri.affine), + apply_affine(reference_mri.affine, upsampled_indices), + ) + seg_indices = np.rint(seg_indices).astype(int) + + # The two images does not necessarily share field of view. + # Remove voxels which are not located within the segmentation fov. + valid_index_mask = (seg_indices > 0).all(axis=1) * (seg_indices < shape_in).all(axis=1) + upsampled_indices = upsampled_indices[valid_index_mask] + seg_indices = seg_indices[valid_index_mask] + + seg_upsampled = np.zeros(shape_out, dtype=self.mri.data.dtype) + I_in, J_in, K_in = seg_indices.T + I_out, J_out, K_out = upsampled_indices.T + seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] + + # return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + mri = MRIData(data=seg_upsampled, affine=reference_mri.affine) + return Segmentation(mri=mri, lut=self.lut) + + def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> "Segmentation": + """ + Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map. + + Args: + sigma (float): The standard deviation for the Gaussian kernel. + cutoff_score (float, optional): A threshold to remove low-confidence voxels. Defaults to 0.5. + **kwargs: Additional keyword arguments passed to scipy.ndimage.gaussian_filter. + + Returns: + dict[str, np.ndarray]: A dictionary containing 'labels' (the smoothed segmentation) + and 'scores' (the confidence scores for each voxel). + """ + smoothed_rois = np.zeros_like(self.mri.data) + high_scores = np.zeros(self.mri.data.shape) + + for roi in self.rois: + scores = scipy.ndimage.gaussian_filter((self.mri.data == roi).astype(float), sigma=sigma, **kwargs) + is_new_high_score = scores > high_scores + smoothed_rois[is_new_high_score] = roi + high_scores[is_new_high_score] = scores[is_new_high_score] + + delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0) + smoothed_rois[delete_scores] = 0 + + mri = MRIData(data=smoothed_rois, affine=self.mri.affine) + return Segmentation(mri=mri, lut=self.lut) class FreeSurferSegmentation(Segmentation): @@ -179,8 +313,8 @@ def from_file( # FreeSurfer LUTs index by the "label" column lut = lut.set_index("label") if "label" in lut.columns else lut - data, affine = load_mri_data(filepath, dtype=dtype, orient=orient) - return cls(data=data, affine=affine, lut=lut) + mri = MRIData.from_file(filepath, dtype=dtype, orient=orient) + return cls(mri=mri, lut=lut) class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): @@ -218,7 +352,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr left_on="FreeSurfer_ROI", right_on="FreeSurfer_ROI", how="outer", - ).drop(columns=["FreeSurfer_ROI"])[["ROI", self._label_name, "tissue_type"]] + ).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]] def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame: """ @@ -249,6 +383,48 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF return ret +@dataclass +class CSFSegmentation: + """ + A specialized segmentation class for isolating Cerebrospinal Fluid (CSF) regions. + + This class combines a standard anatomical segmentation (e.g., FreeSurfer) with a + binary mask specifically targeting CSF regions. It provides functionality to + generate a new segmentation volume where only the CSF-labeled voxels are retained, + while all other voxels are set to zero. + + Args: + segmentation (Segmentation): The anatomical segmentation containing the full set of labels. + csf_mask (MRIData): A binary mask isolating the CSF regions, aligned in the same space as the segmentation. + """ + + segmentation: Segmentation + csf_mask: MRIData + + def __post_init__(self): + assert_same_space(self.segmentation.mri, self.csf_mask) + + @classmethod + def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation": + segmentation = Segmentation.from_file(segmentation_path, dtype=np.int16) + csf_mask = MRIData.from_file(csf_mask_path, dtype=bool) + assert_same_space(segmentation.mri, csf_mask) + return cls(segmentation=segmentation, csf_mask=csf_mask) + + def to_csf_segmentation(self) -> MRIData: + """Generates a new MRIData object containing only the CSF-labeled + voxels from the original segmentation.""" + # Get interpolation operator + I, J, K = np.where(self.segmentation.mri.data != 0) + interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.mri.data[I, J, K]) + # Interpolate segmentation values at CSF mask locations + i, j, k = np.where(self.csf_mask.data != 0) + csf_seg = np.zeros_like(self.segmentation.mri.data, dtype=np.int16) + csf_seg[i, j, k] = interp(i, j, k) + + return MRIData(data=csf_seg.astype(np.int16), affine=self.csf_mask.affine) + + def default_segmentation_groups() -> dict[str, list[int]]: """ Returns the default grouping of FreeSurfer labels into brain regions. @@ -407,3 +583,80 @@ def write_lut(filename: Path, table: pd.DataFrame): # Save as tab-separated values without headers or indices newtable.to_csv(filename, sep="\t", index=False, header=False) + + +def add_arguments( + parser: argparse.ArgumentParser, + extra_args_cb: Callable[[argparse.ArgumentParser], None] | None = None, +) -> None: + subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing") + + resample_parser = subparser.add_parser( + "resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class + ) + resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + resample_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) + resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation") + + smooth_parser = subparser.add_parser( + "smooth", + help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", + formatter_class=parser.formatter_class, + ) + smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file") + smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + smooth_parser.add_argument( + "-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)" + ) + smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation") + + refine_parser = subparser.add_parser( + "refine", + help="Refine a segmentation by applying Gaussian smoothing to the labels", + formatter_class=parser.formatter_class, + ) + refine_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + refine_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) + refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation") + + if extra_args_cb is not None: + extra_args_cb(resample_parser) + extra_args_cb(smooth_parser) + extra_args_cb(refine_parser) + + +def dispatch(args): + command = args.pop("seg-command") + if command == "resample": + print("Resampling segmentation...") + input_seg = Segmentation.from_file(args.pop("input")) + reference_mri = MRIData.from_file(args.pop("reference")) + resampled_seg = input_seg.resample_to_reference(reference_mri) + resampled_seg.save(args.pop("output"), dtype=np.int32) + + elif command == "smooth": + smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff")) + smoothed.save(args.pop("output"), dtype=np.int32) + + elif command == "refine": + seg = Segmentation.from_file(args.pop("input")) + refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) + smoothed = refined.smooth(sigma=args.pop("smooth")) + refined.mri.data = np.where(smoothed.mri.data > 0, smoothed.mri.data, refined.mri.data) + refined.save(args.pop("output"), dtype=np.int32) + + else: + raise ValueError(f"Unknown segmentation command: {command}") diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index 9fb4ca5..b5ccc10 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -219,7 +219,7 @@ def generate_stats_dataframe_rois( metadata: Optional[dict] = None, ) -> pd.DataFrame: # Verify that segmentation and MRI are in the same space - assert_same_space(seg, mri) + assert_same_space(seg.mri, mri) qoi_records = [] # Collects records related to qois roi_records = [] # Collects records related to ROIs, @@ -228,7 +228,7 @@ def generate_stats_dataframe_rois( finite_mask = np.isfinite(mri.data) for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)): # Identify rois in segmentation - region_mask = (seg.data == roi) * finite_mask + region_mask = (seg.mri.data == roi) * finite_mask # print(region_mask.shape) region_data = mri.data[region_mask] nb_nans = np.isnan(region_data).sum() @@ -239,7 +239,7 @@ def generate_stats_dataframe_rois( { "ROI": roi, "voxel_count": voxelcount, - "volume_ml": seg.voxel_ml_volume * voxelcount, + "volume_ml": seg.mri.voxel_ml_volume * voxelcount, "num_nan_values": nb_nans, } ) diff --git a/tests/conftest.py b/tests/conftest.py index be56825..599b4c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def example_segmentation() -> Segmentation: base = np.array([0, 1, 2, 3], dtype=float) seg = np.tile(base, (100, 1)) - return Segmentation(seg, affine=np.eye(4)) + return Segmentation(MRIData(data=seg, affine=np.eye(4))) @pytest.fixture diff --git a/tests/create_test_data.py b/tests/create_test_data.py index 811a335..68571b0 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -21,11 +21,18 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-looklocker_T1map.nii.gz", - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_acq-looklocker_T1map_registered.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aparc+aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-wmparc.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aseg_refined.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", + "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] for file in files: diff --git a/tests/test_masks.py b/tests/test_masks.py index 2494e43..d2bfa01 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -12,7 +12,15 @@ import numpy as np import mritk.cli -from mritk.masks import compute_csf_mask_array, compute_intracranial_mask_array, csf_mask, intracranial_mask, largest_island +from mritk.data import MRIData +from mritk.masks import ( + compute_csf_mask_array, + compute_intracranial_mask_array, + csf_mask, + intracranial_mask, + largest_island, +) +from mritk.segmentation import Segmentation from mritk.testing import compare_nifti_images @@ -118,7 +126,8 @@ def test_csf_mask_io(tmp_path): nii = nib.Nifti1Image(data, np.eye(4)) nib.save(nii, in_path) - result = csf_mask(input=in_path, use_li=True) + input_data = mritk.data.MRIData.from_file(in_path, dtype=np.single) + result = csf_mask(input=input_data, use_li=True) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -146,7 +155,7 @@ def test_intracranial_mask_io(tmp_path): seg_data[4:6, 4:6, 4:6] = 1.0 nib.save(nib.Nifti1Image(seg_data, affine), seg_path) - result = intracranial_mask(csf_segmentation_path=csf_path, segmentation_path=seg_path) + result = intracranial_mask(segmentation=Segmentation(mri=MRIData(seg_data, affine)), csf_mask=MRIData(csf_data, affine)) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -156,32 +165,35 @@ def test_intracranial_mask_io(tmp_path): @patch("mritk.masks.csf_mask") -def test_dispatch_csf_mask(mock_csf_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_csf_mask(mock_from_file, mock_csf_mask): """Test the CLI dispatch for the CSF mask command.""" - mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "--output", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) + mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "-o", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) - mock_csf_mask.assert_called_once_with(input=Path("input.nii.gz"), connectivity=2, use_li=True) + input_data = mock_from_file(Path("input.nii.gz"), dtype=np.single) + mock_csf_mask.assert_called_once_with(input=input_data, connectivity=2, use_li=True) @patch("mritk.masks.intracranial_mask") -def test_dispatch_intracranial_mask(mock_intracranial_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_intracranial_mask(mock_from_file, mock_intracranial_mask): """Test the CLI dispatch for the intracranial mask command.""" mritk.cli.main( [ "mask", "intracranial", - "--csf-segmentation-path", - "csf_segmentation.nii.gz", "--segmentation-path", "segmentation.nii.gz", + "--csf-mask-path", + "csf_mask.nii.gz", "-o", "ic_mask.nii.gz", ] ) - mock_intracranial_mask.assert_called_once_with( - csf_segmentation_path=Path("csf_segmentation.nii.gz"), segmentation_path=Path("segmentation.nii.gz") - ) + seg_data = mock_from_file(Path("segmentation.nii.gz"), dtype=np.single) + csf_data = mock_from_file(Path("csf_mask.nii.gz"), dtype=np.single) + mock_intracranial_mask.assert_called_once_with(segmentation=seg_data, csf_mask=csf_data) def test_csf_mask(tmp_path, mri_data_dir: Path): @@ -191,18 +203,22 @@ def test_csf_mask(tmp_path, mri_data_dir: Path): ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" test_output = tmp_path / "output_seg-csf_binary.nii.gz" - result = csf_mask(input=input_T2w_path, use_li=use_li) + input_T2w = mritk.data.MRIData.from_file(input_T2w_path, dtype=np.single) + result = csf_mask(input=input_T2w, use_li=use_li) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) def test_intracranial_mask(tmp_path, mri_data_dir: Path): - csf_segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz" + csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz" ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz" test_output = tmp_path / "output_seg-intracranial_binary.nii.gz" - result = intracranial_mask(csf_segmentation_path=csf_segmentation_path, segmentation_path=segmentation_path) + input_segmentation = mritk.segmentation.Segmentation.from_file(segmentation_path, dtype=np.single) + input_csf_mask = mritk.data.MRIData.from_file(csf_mask_path, dtype=np.single) + + result = intracranial_mask(segmentation=input_segmentation, csf_mask=input_csf_mask) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) diff --git a/tests/test_mri_io.py b/tests/test_mri_io.py index 57aa442..c98f9a3 100644 --- a/tests/test_mri_io.py +++ b/tests/test_mri_io.py @@ -55,8 +55,9 @@ def test_load_mri_data_invalid_suffix(mri_data_dir): @pytest.mark.parametrize("orient", (True, False)) def test_load_Segmentation(tmp_path, mri_data_dir, orient: bool): input_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - seg = Segmentation.from_file(input_file) - assert seg.data.dtype == int + seg = Segmentation.from_file(input_file, dtype=np.int32) + + assert seg.mri.data.dtype == np.int32 mri = MRIData.from_file(input_file, dtype=np.single, orient=orient) output_file = tmp_path.with_suffix(".nii.gz") mri.save(output_file, dtype=np.single) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index a368fde..899cfd4 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -5,9 +5,12 @@ import pandas as pd import pytest +import mritk.cli +from mritk.data import MRIData from mritk.segmentation import ( LUT_REGEX, VENTRICLES, + CSFSegmentation, ExtendedFreeSurferSegmentation, Segmentation, default_segmentation_groups, @@ -20,8 +23,8 @@ def test_segmentation_initialization(example_segmentation: Segmentation): - assert example_segmentation.data.shape == (100, 4) - assert example_segmentation.affine.shape == (4, 4) + assert example_segmentation.mri.data.shape == (100, 4) + assert example_segmentation.mri.affine.shape == (4, 4) assert example_segmentation.num_rois == 3 assert set(example_segmentation.roi_labels) == {1, 2, 3} assert example_segmentation.lut.shape == (3, 1) @@ -44,11 +47,11 @@ def test_freesurfer_segmentation_labels(mri_data_dir: Path): def test_extended_freesurfer_segmentation_labels(example_segmentation: Segmentation, mri_data_dir: Path): - data = example_segmentation.data + data = example_segmentation.mri.data data[0:2, 0:2] = 10001 # csf data[3:5, 3:5] = 20001 # dura - ext_fs_seg = ExtendedFreeSurferSegmentation(data, affine=np.eye(4)) + ext_fs_seg = ExtendedFreeSurferSegmentation(MRIData(data=data, affine=np.eye(4))) labels = ext_fs_seg.get_roi_labels() assert set(labels["ROI"]) == set(ext_fs_seg.roi_labels) @@ -180,3 +183,133 @@ def test_write_lut_file_io(tmp_path): # Verify the denormalization restored the original 0-255 integers assert content[0] == "4\tLeft-Lateral-Ventricle\t120\t18\t134\t0" assert content[1] == "5\tLeft-Inf-Lat-Vent\t198\t51\t122\t0" + + +# Note : Refinement is actually testing both resampling and smoothing +# @pytest.mark.xfail( +# reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") +# ) +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): + # Get gonzo_roi from FS_segmentation + FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" + fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type + vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) + v = fs_seg.mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Get gonzo_roi from reference MRI to use as reference for resampling + ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" + ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_mri.affine) + v = ref_mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Output: Refine segmentation from gonzoi_roi segmentation and ref MRI + test_output = tmp_path / "output_refined.nii.gz" + + smoothing = 1 + piece_fs_seg = Segmentation(mri=piece_fs_seg_data) + result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) + smoothed = result.smooth(sigma=smoothing) + result.mri.data = smoothed.mri.data + result.save(test_output, dtype=np.int32) + + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) + + +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): + """Test the CSF segmentation logic by comparing against a known reference.""" + input_seg_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" + + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + + input_seg = MRIData.from_file(input_seg_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_seg.affine) + v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + piece_seg = Segmentation(mri=piece_seg_data) + + input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine) + v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + result = CSFSegmentation(segmentation=piece_seg, csf_mask=piece_csf_mask_data).to_csf_segmentation() + + ref_output = MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + + +@patch("mritk.segmentation.MRIData") +@patch("mritk.segmentation.Segmentation") +def test_dispatch_resample(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation resample.""" + + mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) + + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + + +@patch("mritk.segmentation.Segmentation") +def test_dispatch_smoothing(mock_seg): + """Test that dispatch correctly routes to segmentation smoothing.""" + + mritk.cli.main(["seg", "smooth", "-i", "mock_in.nii.gz", "-o", "mock_out.nii.gz", "-s", "1"]) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.smooth.assert_called_once_with(sigma=1.0, cutoff_score=0.5) + + +@patch("mritk.segmentation.MRIData") +@patch("mritk.segmentation.Segmentation") +def test_dispatch_refine(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation refinement.""" + + # Mock the underlying data arrays to avoid TypeError in np.where + inst = mock_seg.from_file.return_value + refined_inst = inst.resample_to_reference.return_value + smoothed_inst = refined_inst.smooth.return_value + + # Setup mock numpy arrays for the attributes used in np.where + smoothed_inst.data = np.array([1]) # In case the source code bug isn't fixed yet + refined_inst.data = np.array([0]) # In case the source code bug isn't fixed yet + refined_inst.mri.data = np.array([0]) # Correct fixed access + smoothed_inst.mri.data = np.array([1]) # Correct fixed access + + mritk.cli.main( + [ + "seg", + "refine", + "-i", + "mock_in.nii.gz", + "-r", + "mock_ref.nii.gz", + "-o", + "mock_out.nii.gz", + "-s", + "1", + ] + ) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) + + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + refined_inst.smooth.assert_called_once_with(sigma=1.0) + refined_inst.save.assert_called_once_with(Path("mock_out.nii.gz"), dtype=np.int32)