Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions src/mritk/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,35 @@
}


@dataclass
@dataclass(init=False)
class Segmentation:
"""
Base class for MRI segmentations, linking spatial data with anatomical lookup tables.

This class extends MRIData by specifically treating the image array as discrete
integer labels representing Regions of Interest (ROIs). It links these numerical
labels to a descriptive Lookup Table (LUT).

Args:
mri (MRIData): The MRIData object containing the segmentation volume and affine.
lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
"""

def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
"""
Initializes the Segmentation object.
mri: MRIData
lut: pd.DataFrame
label_name: str
rois: np.ndarray

Args:
data (np.ndarray): 3D numpy array containing integer ROI labels.
affine (np.ndarray): 4x4 affine transformation matrix mapping voxel indices to physical space.
lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels
to their descriptions. If None, a default numerical mapping is generated. Defaults to None.
"""
def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None):
self.mri = mri

# Extract all unique active regions (ignoring 0/background)
self.rois = np.unique(self.mri.data[self.mri.data > 0])

if lut is not None:
self.lut = lut
else:
if lut is None:
self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois)
else:
self.lut = lut

# Identify the primary label column dynamically
self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0]
Expand Down Expand Up @@ -385,13 +385,24 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF

@dataclass
class CSFSegmentation:
"""
A specialized segmentation class for isolating Cerebrospinal Fluid (CSF) regions.

This class combines a standard anatomical segmentation (e.g., FreeSurfer) with a
binary mask specifically targeting CSF regions. It provides functionality to
generate a new segmentation volume where only the CSF-labeled voxels are retained,
while all other voxels are set to zero.

Args:
segmentation (Segmentation): The anatomical segmentation containing the full set of labels.
csf_mask (MRIData): A binary mask isolating the CSF regions, aligned in the same space as the segmentation.
"""

segmentation: Segmentation
csf_mask: MRIData

def __init__(self, segmentation: Segmentation, csf_mask: MRIData):
assert_same_space(segmentation.mri, csf_mask)
self.segmentation = segmentation
self.csf_mask = csf_mask
def __post_init__(self):
assert_same_space(self.segmentation.mri, self.csf_mask)

@classmethod
def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation":
Expand All @@ -401,6 +412,8 @@ def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmenta
return cls(segmentation=segmentation, csf_mask=csf_mask)

def to_csf_segmentation(self) -> MRIData:
"""Generates a new MRIData object containing only the CSF-labeled
voxels from the original segmentation."""
# Get interpolation operator
I, J, K = np.where(self.segmentation.mri.data != 0)
interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.mri.data[I, J, K])
Expand Down
Loading