From 4351949c5a5d9e97d391850824cd6722dd8d9f5b Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:41:59 +0200 Subject: [PATCH 1/4] Various fixes --- .github/workflows/setup-data.yml | 2 +- src/mritk/segmentation.py | 52 +++++++++++++++++++++------ src/mritk/statistics/compute_stats.py | 6 ++-- tests/create_test_data.py | 1 + tests/test_mri_io.py | 5 +-- tests/test_segmentation.py | 40 +++++++++++++++++---- 6 files changed, 83 insertions(+), 23 deletions(-) diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 34548a8..2265a39 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -22,7 +22,7 @@ jobs: path: data/ # The folder you want to cache # The key determines if we have a match. # Change 'v1' to 'v2' manually to force a re-download in the future. - key: test-data-v8 + key: test-data-v9 # 2. DOWNLOAD ONLY IF CACHE MISS diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 34ce39c..56cc514 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -122,30 +122,57 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] @classmethod - def from_file(cls, seg_path: Path) -> "Segmentation": + def from_file( + cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + ) -> "Segmentation": """Loads a Segmentation from a NIfTI file. Args: seg_path (Path): The file path to the segmentation NIfTI file. + dtype (npt.DTypeLike, optional): The data type for the segmentation data. Defaults to None. + orient (bool, optional): Whether to orient the data. Defaults to True. + lut_path (Path, optional): The file path to the lookup table. Defaults to None. Returns: Segmentation: An instance of the Segmentation class containing the loaded segmentation data and affine transformation. """ logger.info(f"Loading segmentation from {seg_path}.") - mri = MRIData.from_file(seg_path, dtype=np.single) + mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient) + + if lut_path is None and seg_path.with_suffix(".json").exists(): + lut_path = seg_path.with_suffix(".json") - rois = np.unique(mri.data[mri.data > 0]) - lut = pd.DataFrame({"Label": rois}, index=rois) + if lut_path is not None: + logger.info(f"Loading LUT from {lut_path}.") + lut = pd.read_json(lut_path) + else: + rois = np.unique(mri.data[mri.data > 0]) + lut = pd.DataFrame({"Label": rois}, index=rois) return cls(mri=mri, lut=lut) + def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None): + """Saves the Segmentation to a NIfTI file. + + Args: + output_path (Path): The file path where the segmentation will be saved. + dtype (npt.DTypeLike, optional): The data type for the saved segmentation data. Defaults to None. + intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL). + """ + self.mri.save(output_path, dtype=dtype, intent_code=intent_code) + if lut_path is not None: + self.lut.to_json(lut_path, orient="index") + else: + self.lut.to_json(output_path.with_suffix(".json"), orient="index") + def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. Args: lut (pd.DataFrame): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. - label_column (str, optional): The name of the column in the LUT that contains the label descriptions. Defaults to "Label". + label_column (str, optional): The name of the column in the LUT that contains the label + descriptions. Defaults to "Label". """ self.lut = lut @@ -188,7 +215,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() - def resample_to_reference(self, reference_mri: MRIData): + def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation": """ Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI. @@ -226,9 +253,10 @@ def resample_to_reference(self, reference_mri: MRIData): seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] # return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) - return MRIData(data=seg_upsampled, affine=reference_mri.affine) + mri = MRIData(data=seg_upsampled, affine=reference_mri.affine) + return Segmentation(mri=mri, lut=self.lut) - def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: + def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> "Segmentation": """ Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map. @@ -253,7 +281,8 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData: delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0) smoothed_rois[delete_scores] = 0 - return MRIData(data=smoothed_rois, affine=self.mri.affine) + mri = MRIData(data=smoothed_rois, affine=self.mri.affine) + return Segmentation(mri=mri, lut=self.lut) class FreeSurferSegmentation(Segmentation): @@ -607,14 +636,17 @@ def dispatch(args): reference_mri = MRIData.from_file(args.pop("reference")) resampled_seg = input_seg.resample_to_reference(reference_mri) resampled_seg.save(args.pop("output"), dtype=np.int32) + elif command == "smooth": smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff")) smoothed.save(args.pop("output"), dtype=np.int32) + elif command == "refine": seg = Segmentation.from_file(args.pop("input")) refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) smoothed = refined.smooth(sigma=args.pop("smooth")) - refined.data = np.where(smoothed.data > 0, smoothed.data, refined.data) + refined.mri.data = np.where(smoothed.data > 0, smoothed.data, refined.data) refined.save(args.pop("output"), dtype=np.int32) + else: raise ValueError(f"Unknown segmentation command: {command}") diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index 9fb4ca5..b5ccc10 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -219,7 +219,7 @@ def generate_stats_dataframe_rois( metadata: Optional[dict] = None, ) -> pd.DataFrame: # Verify that segmentation and MRI are in the same space - assert_same_space(seg, mri) + assert_same_space(seg.mri, mri) qoi_records = [] # Collects records related to qois roi_records = [] # Collects records related to ROIs, @@ -228,7 +228,7 @@ def generate_stats_dataframe_rois( finite_mask = np.isfinite(mri.data) for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)): # Identify rois in segmentation - region_mask = (seg.data == roi) * finite_mask + region_mask = (seg.mri.data == roi) * finite_mask # print(region_mask.shape) region_data = mri.data[region_mask] nb_nans = np.isnan(region_data).sum() @@ -239,7 +239,7 @@ def generate_stats_dataframe_rois( { "ROI": roi, "voxel_count": voxelcount, - "volume_ml": seg.voxel_ml_volume * voxelcount, + "volume_ml": seg.mri.voxel_ml_volume * voxelcount, "num_nan_values": nb_nans, } ) diff --git a/tests/create_test_data.py b/tests/create_test_data.py index e487ed3..68571b0 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -29,6 +29,7 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", + "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", diff --git a/tests/test_mri_io.py b/tests/test_mri_io.py index 57aa442..c98f9a3 100644 --- a/tests/test_mri_io.py +++ b/tests/test_mri_io.py @@ -55,8 +55,9 @@ def test_load_mri_data_invalid_suffix(mri_data_dir): @pytest.mark.parametrize("orient", (True, False)) def test_load_Segmentation(tmp_path, mri_data_dir, orient: bool): input_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - seg = Segmentation.from_file(input_file) - assert seg.data.dtype == int + seg = Segmentation.from_file(input_file, dtype=np.int32) + + assert seg.mri.data.dtype == np.int32 mri = MRIData.from_file(input_file, dtype=np.single, orient=orient) output_file = tmp_path.with_suffix(".nii.gz") mri.save(output_file, dtype=np.single) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index ea77ab6..7d13493 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -187,13 +187,11 @@ def test_write_lut_file_io(tmp_path): # Note : Refinement is actually testing both resampling and smoothing -# @pytest.mark.skip(reason="Takes too long to run") -@pytest.mark.xfail( - reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") -) +# @pytest.mark.xfail( +# reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") +# ) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): - # Get gonzo_roi from FS_segmentation FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type @@ -215,7 +213,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty piece_fs_seg = Segmentation(mri=piece_fs_seg_data) result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) smoothed = result.smooth(sigma=smoothing) - result.data = smoothed.data + result.mri.data = smoothed.mri.data result.save(test_output, dtype=np.int32) ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" @@ -223,7 +221,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty vi = gonzo_roi.voxel_indices(affine=ref_output.affine) v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) - mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) @@ -265,3 +263,31 @@ def test_dispatch_smoothing(mock_seg): mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + + +@patch("mritk.segmentation.Segmentation") +@patch("mritk.data.MRIData") +def test_dispatch_refine(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation refinement.""" + + mritk.cli.main( + [ + "seg", + "refine", + "-i", + "mock_in.nii.gz", + "-r", + "mock_ref.nii.gz", + "-o", + "mock_out.nii.gz", + "-s", + "1", + ] + ) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) From 78ef8e10bf1170569a272bdc9679a51a18d8b984 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:48:57 +0200 Subject: [PATCH 2/4] Fix shape issue in refinement test --- tests/test_segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 7d13493..27561b2 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -196,14 +196,14 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) - v = fs_seg.mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v = fs_seg.mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) # Get gonzo_roi from reference MRI to use as reference for resampling ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=ref_mri.affine) - v = ref_mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v = ref_mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) # Output: Refine segmentation from gonzoi_roi segmentation and ref MRI @@ -219,7 +219,7 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single) vi = gonzo_roi.voxel_indices(affine=ref_output.affine) - v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1)) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) From 97e4acfa14701eebb0d59cee6810ca87f3fc0277 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:53:36 +0200 Subject: [PATCH 3/4] Use gonzo roi for csf segmentation test --- tests/test_segmentation.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index 27561b2..b930d60 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -20,7 +20,6 @@ validate_lut_file, write_lut, ) -from mritk.testing import compare_nifti_images def test_segmentation_initialization(example_segmentation: Segmentation): @@ -225,19 +224,30 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty @pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) -def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type): +def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): """Test the CSF segmentation logic by comparing against a known reference.""" - input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz" + input_seg_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" - ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" - test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz" + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + + input_seg = MRIData.from_file(input_seg_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_seg.affine) + v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) - input_T2w = MRIData.from_file(input_T2w_path, dtype=np.single) input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) - result = CSFSegmentation(segmentation=input_T2w, csf_mask=input_csf_mask).to_csf_segmentation() - result.save(test_output, dtype=np.uint8) - compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) + vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine) + v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + result = CSFSegmentation(segmentation=piece_seg_data, csf_mask=piece_csf_mask_data).to_csf_segmentation() + + ref_output = MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) @patch("mritk.segmentation.Segmentation") From e5e28fed7f52b6c332d1bc136fc1f6a256259dab Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 24 Apr 2026 20:58:16 +0200 Subject: [PATCH 4/4] Fix dispatch tests in segmentation --- src/mritk/segmentation.py | 2 +- tests/test_segmentation.py | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 56cc514..a55e610 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -645,7 +645,7 @@ def dispatch(args): seg = Segmentation.from_file(args.pop("input")) refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) smoothed = refined.smooth(sigma=args.pop("smooth")) - refined.mri.data = np.where(smoothed.data > 0, smoothed.data, refined.data) + refined.mri.data = np.where(smoothed.mri.data > 0, smoothed.mri.data, refined.mri.data) refined.save(args.pop("output"), dtype=np.int32) else: diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index b930d60..aa12d83 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -250,15 +250,15 @@ def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) +@patch("mritk.segmentation.MRIData") @patch("mritk.segmentation.Segmentation") -@patch("mritk.data.MRIData") def test_dispatch_resample(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation resample.""" mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) - mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"), dtype=np.single) - mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) @@ -272,14 +272,25 @@ def test_dispatch_smoothing(mock_seg): mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file - inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + inst.smooth.assert_called_once_with(sigma=1.0, cutoff_score=0.5) +@patch("mritk.segmentation.MRIData") @patch("mritk.segmentation.Segmentation") -@patch("mritk.data.MRIData") def test_dispatch_refine(mock_seg, mock_mri_data): """Test that dispatch correctly routes to segmentation refinement.""" + # Mock the underlying data arrays to avoid TypeError in np.where + inst = mock_seg.from_file.return_value + refined_inst = inst.resample_to_reference.return_value + smoothed_inst = refined_inst.smooth.return_value + + # Setup mock numpy arrays for the attributes used in np.where + smoothed_inst.data = np.array([1]) # In case the source code bug isn't fixed yet + refined_inst.data = np.array([0]) # In case the source code bug isn't fixed yet + refined_inst.mri.data = np.array([0]) # Correct fixed access + smoothed_inst.mri.data = np.array([1]) # Correct fixed access + mritk.cli.main( [ "seg", @@ -296,8 +307,8 @@ def test_dispatch_refine(mock_seg, mock_mri_data): ) mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) - mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) - inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) - inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5) + refined_inst.smooth.assert_called_once_with(sigma=1.0) + refined_inst.save.assert_called_once_with(Path("mock_out.nii.gz"), dtype=np.int32)