Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/setup-data.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
path: data/ # The folder you want to cache
# The key determines if we have a match.
# Change 'v1' to 'v2' manually to force a re-download in the future.
key: test-data-v8
key: test-data-v9


# 2. DOWNLOAD ONLY IF CACHE MISS
Expand Down
52 changes: 42 additions & 10 deletions src/mritk/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,57 @@ def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0]

@classmethod
def from_file(cls, seg_path: Path) -> "Segmentation":
def from_file(
cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None
) -> "Segmentation":
"""Loads a Segmentation from a NIfTI file.

Args:
seg_path (Path): The file path to the segmentation NIfTI file.
dtype (npt.DTypeLike, optional): The data type for the segmentation data. Defaults to None.
orient (bool, optional): Whether to orient the data. Defaults to True.
lut_path (Path, optional): The file path to the lookup table. Defaults to None.
Returns:
Segmentation: An instance of the Segmentation class containing the loaded
segmentation data and affine transformation.
"""
logger.info(f"Loading segmentation from {seg_path}.")
mri = MRIData.from_file(seg_path, dtype=np.single)
mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient)

if lut_path is None and seg_path.with_suffix(".json").exists():
lut_path = seg_path.with_suffix(".json")

rois = np.unique(mri.data[mri.data > 0])
lut = pd.DataFrame({"Label": rois}, index=rois)
if lut_path is not None:
logger.info(f"Loading LUT from {lut_path}.")
lut = pd.read_json(lut_path)
else:
rois = np.unique(mri.data[mri.data > 0])
lut = pd.DataFrame({"Label": rois}, index=rois)

return cls(mri=mri, lut=lut)

def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None):
"""Saves the Segmentation to a NIfTI file.

Args:
output_path (Path): The file path where the segmentation will be saved.
dtype (npt.DTypeLike, optional): The data type for the saved segmentation data. Defaults to None.
intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL).
"""
self.mri.save(output_path, dtype=dtype, intent_code=intent_code)
if lut_path is not None:
self.lut.to_json(lut_path, orient="index")
else:
self.lut.to_json(output_path.with_suffix(".json"), orient="index")

def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"):
"""Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs.

Args:
lut (pd.DataFrame): A pandas DataFrame mapping numerical labels
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
label_column (str, optional): The name of the column in the LUT that contains the label descriptions. Defaults to "Label".
label_column (str, optional): The name of the column in the LUT that contains the label
descriptions. Defaults to "Label".
"""

self.lut = lut
Expand Down Expand Up @@ -188,7 +215,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr

return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index()

def resample_to_reference(self, reference_mri: MRIData):
def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation":
"""
Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI.

Expand Down Expand Up @@ -226,9 +253,10 @@ def resample_to_reference(self, reference_mri: MRIData):
seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in]

# return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut)
return MRIData(data=seg_upsampled, affine=reference_mri.affine)
mri = MRIData(data=seg_upsampled, affine=reference_mri.affine)
return Segmentation(mri=mri, lut=self.lut)

def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData:
def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> "Segmentation":
"""
Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map.

Expand All @@ -253,7 +281,8 @@ def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData:
delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0)
smoothed_rois[delete_scores] = 0

return MRIData(data=smoothed_rois, affine=self.mri.affine)
mri = MRIData(data=smoothed_rois, affine=self.mri.affine)
return Segmentation(mri=mri, lut=self.lut)


class FreeSurferSegmentation(Segmentation):
Expand Down Expand Up @@ -607,14 +636,17 @@ def dispatch(args):
reference_mri = MRIData.from_file(args.pop("reference"))
resampled_seg = input_seg.resample_to_reference(reference_mri)
resampled_seg.save(args.pop("output"), dtype=np.int32)

elif command == "smooth":
smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff"))
smoothed.save(args.pop("output"), dtype=np.int32)

elif command == "refine":
seg = Segmentation.from_file(args.pop("input"))
refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference")))
smoothed = refined.smooth(sigma=args.pop("smooth"))
refined.data = np.where(smoothed.data > 0, smoothed.data, refined.data)
refined.mri.data = np.where(smoothed.mri.data > 0, smoothed.mri.data, refined.mri.data)
refined.save(args.pop("output"), dtype=np.int32)

else:
raise ValueError(f"Unknown segmentation command: {command}")
6 changes: 3 additions & 3 deletions src/mritk/statistics/compute_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def generate_stats_dataframe_rois(
metadata: Optional[dict] = None,
) -> pd.DataFrame:
# Verify that segmentation and MRI are in the same space
assert_same_space(seg, mri)
assert_same_space(seg.mri, mri)

qoi_records = [] # Collects records related to qois
roi_records = [] # Collects records related to ROIs,
Expand All @@ -228,7 +228,7 @@ def generate_stats_dataframe_rois(
finite_mask = np.isfinite(mri.data)
for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)):
# Identify rois in segmentation
region_mask = (seg.data == roi) * finite_mask
region_mask = (seg.mri.data == roi) * finite_mask
# print(region_mask.shape)
region_data = mri.data[region_mask]
nb_nans = np.isnan(region_data).sum()
Expand All @@ -239,7 +239,7 @@ def generate_stats_dataframe_rois(
{
"ROI": roi,
"voxel_count": voxelcount,
"volume_ml": seg.voxel_ml_volume * voxelcount,
"volume_ml": seg.mri.voxel_ml_volume * voxelcount,
"num_nan_values": nb_nans,
}
)
Expand Down
1 change: 1 addition & 0 deletions tests/create_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def main():
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz",
"mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz",
Expand Down
5 changes: 3 additions & 2 deletions tests/test_mri_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def test_load_mri_data_invalid_suffix(mri_data_dir):
@pytest.mark.parametrize("orient", (True, False))
def test_load_Segmentation(tmp_path, mri_data_dir, orient: bool):
input_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz"
seg = Segmentation.from_file(input_file)
assert seg.data.dtype == int
seg = Segmentation.from_file(input_file, dtype=np.int32)

assert seg.mri.data.dtype == np.int32
mri = MRIData.from_file(input_file, dtype=np.single, orient=orient)
output_file = tmp_path.with_suffix(".nii.gz")
mri.save(output_file, dtype=np.single)
93 changes: 70 additions & 23 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
validate_lut_file,
write_lut,
)
from mritk.testing import compare_nifti_images


def test_segmentation_initialization(example_segmentation: Segmentation):
Expand Down Expand Up @@ -187,25 +186,23 @@ def test_write_lut_file_io(tmp_path):


# Note : Refinement is actually testing both resampling and smoothing
# @pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.xfail(
reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.")
)
# @pytest.mark.xfail(
# reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.")
# )
@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"])
def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str):

# Get gonzo_roi from FS_segmentation
FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz"
fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type
vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine)
v = fs_seg.mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1))
v = fs_seg.mri.data[tuple(vi.T)].reshape(gonzo_roi.shape)
piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine)

# Get gonzo_roi from reference MRI to use as reference for resampling
ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz"
ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single)
vi = gonzo_roi.voxel_indices(affine=ref_mri.affine)
v = ref_mri.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1))
v = ref_mri.data[tuple(vi.T)].reshape(gonzo_roi.shape)
piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine)

# Output: Refine segmentation from gonzoi_roi segmentation and ref MRI
Expand All @@ -215,42 +212,53 @@ def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_ty
piece_fs_seg = Segmentation(mri=piece_fs_seg_data)
result = piece_fs_seg.resample_to_reference(piece_ref_mri_data)
smoothed = result.smooth(sigma=smoothing)
result.data = smoothed.data
result.mri.data = smoothed.mri.data
result.save(test_output, dtype=np.int32)

ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz"
ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single)
vi = gonzo_roi.voxel_indices(affine=ref_output.affine)
v_ref = ref_output.data[tuple(vi.T)].reshape((*gonzo_roi.shape, -1))
v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape)

mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12)
mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12)


@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"])
def test_csf_segmentation(tmp_path, mri_data_dir: Path, seg_type):
def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type):
"""Test the CSF segmentation logic by comparing against a known reference."""
input_T2w_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz"
input_seg_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz"
input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz"

ref_output = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz"
test_output = tmp_path / f"output_seg-csf-{seg_type}.nii.gz"
ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz"

input_seg = MRIData.from_file(input_seg_path, dtype=np.single)
vi = gonzo_roi.voxel_indices(affine=input_seg.affine)
v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape)
piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine)

input_T2w = MRIData.from_file(input_T2w_path, dtype=np.single)
input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single)
result = CSFSegmentation(segmentation=input_T2w, csf_mask=input_csf_mask).to_csf_segmentation()
result.save(test_output, dtype=np.uint8)
compare_nifti_images(test_output, ref_output, data_tolerance=1e-12)
vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine)
v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape)
piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine)

result = CSFSegmentation(segmentation=piece_seg_data, csf_mask=piece_csf_mask_data).to_csf_segmentation()

ref_output = MRIData.from_file(ref_output_path, dtype=np.single)
vi = gonzo_roi.voxel_indices(affine=ref_output.affine)
v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape)

mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12)


@patch("mritk.segmentation.MRIData")
@patch("mritk.segmentation.Segmentation")
@patch("mritk.data.MRIData")
def test_dispatch_resample(mock_seg, mock_mri_data):
"""Test that dispatch correctly routes to segmentation resample."""

mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"])

mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"), dtype=np.single)
mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"), dtype=np.single)
mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"))
mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"))

inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file
inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value)
Expand All @@ -264,4 +272,43 @@ def test_dispatch_smoothing(mock_seg):

mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"))
inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file
inst.smooth.assert_called_once_with(sigma=1, cutoff_score=0.5)
inst.smooth.assert_called_once_with(sigma=1.0, cutoff_score=0.5)


@patch("mritk.segmentation.MRIData")
@patch("mritk.segmentation.Segmentation")
def test_dispatch_refine(mock_seg, mock_mri_data):
"""Test that dispatch correctly routes to segmentation refinement."""

# Mock the underlying data arrays to avoid TypeError in np.where
inst = mock_seg.from_file.return_value
refined_inst = inst.resample_to_reference.return_value
smoothed_inst = refined_inst.smooth.return_value

# Setup mock numpy arrays for the attributes used in np.where
smoothed_inst.data = np.array([1]) # In case the source code bug isn't fixed yet
refined_inst.data = np.array([0]) # In case the source code bug isn't fixed yet
refined_inst.mri.data = np.array([0]) # Correct fixed access
smoothed_inst.mri.data = np.array([1]) # Correct fixed access

mritk.cli.main(
[
"seg",
"refine",
"-i",
"mock_in.nii.gz",
"-r",
"mock_ref.nii.gz",
"-o",
"mock_out.nii.gz",
"-s",
"1",
]
)

mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz"))
mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz"))

inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value)
refined_inst.smooth.assert_called_once_with(sigma=1.0)
refined_inst.save.assert_called_once_with(Path("mock_out.nii.gz"), dtype=np.int32)
Loading