1717from TPTBox .core .nii_poi_abstract import Has_Grid
1818from TPTBox .core .nii_wrapper_math import NII_Math
1919from TPTBox .core .np_utils import (
20+ _pad_to_parameters ,
2021 np_calc_boundary_mask ,
2122 np_calc_convex_hull ,
2223 np_calc_overlapping_labels ,
2526 np_connected_components ,
2627 np_dilate_msk ,
2728 np_erode_msk ,
29+ np_extract_label ,
2830 np_fill_holes ,
31+ np_fill_holes_global_with_majority_voting ,
2932 np_get_connected_components_center_of_mass ,
3033 np_get_largest_k_connected_components ,
3134 np_map_labels ,
35+ np_map_labels_based_on_majority_label_mask_overlap ,
3236 np_point_coordinates ,
3337 np_smooth_gaussian_labelwise ,
3438 np_unique ,
@@ -670,7 +674,7 @@ def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_cro
670674 #origin_shift = tuple([int(ex_slice[i].start) for i in range(len(ex_slice))])
671675 return tuple (ex_slice )# type: ignore
672676
673- def apply_center_crop (self , center_shape : tuple [int ,int ,int ], verbose : bool = False ):
677+ def apply_center_crop (self , center_shape : tuple [int ,int ,int ], inplace = False , verbose : bool = False ):
674678 shp_x , shp_y , shp_z = self .shape
675679 crop_x , crop_y , crop_z = center_shape
676680 arr = self .get_array ()
@@ -698,7 +702,8 @@ def apply_center_crop(self, center_shape: tuple[int,int,int], verbose: bool = Fa
698702 log .print (f"Center cropped from { arr_padded .shape } to { arr_cropped .shape } " , verbose = verbose )
699703 shp_x , shp_y , shp_z = arr_cropped .shape
700704 assert crop_x == shp_x and crop_y == shp_y and crop_z == shp_z
701- return self .set_array (arr_cropped )
705+ return self .set_array (arr_cropped , inplace = inplace )
706+ #return self.apply_crop(crop_slices, inplace=inplace)
702707
703708 def apply_crop_slice (self ,* args ,** qargs ):
704709 import warnings
@@ -733,22 +738,7 @@ def apply_crop_(self,ex_slice:tuple[slice,slice,slice]|Sequence[slice]):
733738 def pad_to (self ,target_shape :list [int ]| tuple [int ,int ,int ] | Self , mode :MODES = "constant" ,crop = False ,inplace = False ):
734739 if isinstance (target_shape , NII ):
735740 target_shape = target_shape .shape
736- padding = []
737- crop = []
738- requires_crop = False
739- for in_size , out_size in zip (self .shape [- 3 :], target_shape [- 3 :],strict = True ):
740- to_pad_size = max (0 , out_size - in_size ) / 2.0
741- to_crop_size = - min (0 , out_size - in_size ) / 2.0
742- padding .extend ([(ceil (to_pad_size ), floor (to_pad_size ))])
743- if to_crop_size == 0 :
744- crop .append (slice (None ))
745- else :
746- end = - floor (to_crop_size )
747- if end == 0 :
748- end = None
749- crop .append (slice (ceil (to_crop_size ), end ))
750- requires_crop = True
751-
741+ padding , crop , requires_crop = _pad_to_parameters (self .shape , target_shape )
752742 s = self
753743 if crop and requires_crop :
754744 s = s .apply_crop (tuple (crop ),inplace = inplace )
@@ -1115,26 +1105,13 @@ def erode_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectivi
11151105
11161106 """
11171107 log .print ("erode mask" ,end = '\r ' ,verbose = verbose )
1118- if use_crop :
1119- try :
1120- crop = (self if labels is None else self .extract_label (labels )).compute_crop (dist = 1 )
1121- except ValueError :
1122- return self if inplace else self .copy ()
1123-
1124- msk_i_data_org = self .get_seg_array ()
1125- msk_i_data = msk_i_data_org [crop ]
1126- else :
1127- msk_i_data = self .get_seg_array ()
1108+ msk_i_data = self .get_seg_array ()
11281109 labels = self .unique () if labels is None else labels
11291110 if isinstance (ignore_direction ,str ):
11301111 ignore_direction = self .get_axis (ignore_direction )
1131- out = np_erode_msk (msk_i_data , label_ref = labels , mm = n_pixel , connectivity = connectivity ,border_value = border_value ,ignore_axis = ignore_direction )
1112+ out = np_erode_msk (msk_i_data , label_ref = labels , n_pixel = n_pixel , connectivity = connectivity ,border_value = border_value ,ignore_axis = ignore_direction , use_crop = use_crop )
11321113 out = out .astype (self .dtype )
11331114 log .print ("Mask eroded by" , n_pixel , "voxels" ,verbose = verbose )
1134-
1135- if use_crop :
1136- msk_i_data_org [crop ] = out
1137- out = msk_i_data_org
11381115 return self .set_array (out ,inplace = inplace )
11391116
11401117 def erode_msk_ (self , n_pixel :int = 5 , labels : LABEL_REFERENCE = None , connectivity : int = 3 , verbose :logging = True ,border_value = 0 ,use_crop = True ,ignore_direction :DIRECTIONS | int | None = None ):
@@ -1162,24 +1139,13 @@ def dilate_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectiv
11621139 log .print ("dilate mask" ,end = '\r ' ,verbose = verbose )
11631140 if labels is None :
11641141 labels = self .unique ()
1165- if use_crop :
1166- try :
1167- crop = (self if labels is None else self .extract_label (labels )).compute_crop (dist = 1 + n_pixel )
1168- except ValueError :
1169- return self if inplace else self .copy ()
1170- msk_i_data_org = self .get_seg_array ()
1171- msk_i_data = msk_i_data_org [crop ]
1172- else :
1173- msk_i_data = self .get_seg_array ()
1142+ msk_i_data = self .get_seg_array ()
11741143 mask_ = mask .get_seg_array () if mask is not None else None
11751144 if isinstance (ignore_direction ,str ):
11761145 ignore_direction = self .get_axis (ignore_direction )
1177- out = np_dilate_msk (arr = msk_i_data , label_ref = labels , mm = n_pixel , mask = mask_ , connectivity = connectivity ,ignore_axis = ignore_direction )
1146+ out = np_dilate_msk (arr = msk_i_data , label_ref = labels , n_pixel = n_pixel , mask = mask_ , connectivity = connectivity ,ignore_axis = ignore_direction , use_crop = use_crop )
11781147 out = out .astype (self .dtype )
11791148 log .print ("Mask dilated by" , n_pixel , "voxels" ,verbose = verbose )
1180- if use_crop :
1181- msk_i_data_org [crop ] = out
1182- out = msk_i_data_org
11831149 return self .set_array (out ,inplace = inplace )
11841150
11851151 def dilate_msk_ (self , n_pixel :int = 5 , labels : LABEL_REFERENCE = None , connectivity : int = 3 , mask : Self | None = None , verbose :logging = True ,use_crop = True ,ignore_direction :DIRECTIONS | int | None = None ):
@@ -1218,7 +1184,7 @@ def fill_holes(self, labels: LABEL_REFERENCE = None, slice_wise_dim: int|str | N
12181184 if isinstance (slice_wise_dim ,str ):
12191185 slice_wise_dim = self .get_axis (slice_wise_dim )
12201186 #seg_arr = self.get_seg_array()
1221- filled = np_fill_holes (seg_arr , label_ref = labels , slice_wise_dim = slice_wise_dim )
1187+ filled = np_fill_holes (seg_arr , label_ref = labels , slice_wise_dim = slice_wise_dim , use_crop = use_crop )
12221188 if use_crop :
12231189 msk_i_data_org [crop ] = filled
12241190 filled = msk_i_data_org
@@ -1378,6 +1344,38 @@ def compute_surface_points(self, connectivity: int, dilated_surface: bool = Fals
13781344 return np_point_coordinates (surface .get_seg_array ()) # type: ignore
13791345
13801346
1347+ def fill_holes_global_with_majority_voting (self , connectivity : int = 3 , inplace : bool = False , verbose : bool = False ):
1348+ """Fills 3D holes globally, and resolves inter-label conflicts with majority voting by neighbors
1349+
1350+ Args:
1351+ connectivity (int, optional): Connectivity of fill holes. Defaults to 3.
1352+ inplace (bool, optional): Defaults to False.
1353+ verbose (bool, optional): Defaults to False.
1354+
1355+ Returns:
1356+ NII:
1357+ """
1358+ assert self .seg , "only works with segmentation masks"
1359+ arr = np_fill_holes_global_with_majority_voting (self .get_seg_array (), connectivity = connectivity , verbose = verbose , inplace = inplace )
1360+ return self .set_array (arr ,inplace = inplace )
1361+
1362+
1363+ def map_labels_based_on_majority_label_mask_overlap (self , label_mask : Self , labels : int | list [int ] | None = None , dilate_pixel : int = 1 , inplace : bool = False ):
1364+ """Relabels all individual labels from input array to the majority labels of a given label_mask
1365+
1366+ Args:
1367+ label_mask (np.ndarray): the mask from which to pull the target labels.
1368+ labels (int | list[int] | None, optional): Which labels in the input to process. Defaults to None.
1369+ dilate_pixel (int, optional): If true, will dilate the input to calculate the overlap. Defaults to 1.
1370+ inplace (bool, optional): Defaults to False.
1371+
1372+ Returns:
1373+ NII: Relabeled nifti
1374+ """
1375+ assert self .seg and label_mask .seg , "This only works on segmentations"
1376+ return self .set_array (np_map_labels_based_on_majority_label_mask_overlap (self .get_seg_array (), label_mask .get_seg_array (), labels = labels , dilate_pixel = dilate_pixel , inplace = inplace ), inplace = inplace ,)
1377+
1378+
13811379 def get_segmentation_difference_to (self , mask_gt : Self , ignore_background_tp : bool = False ) -> Self :
13821380 """Calculates an NII that represents the segmentation difference between self and given groundtruth mask
13831381
@@ -1709,8 +1707,16 @@ def get_intersecting_volume(self, b: Self) -> bool:
17091707 b = b .resample_from_to (self ,c_val = 0 ,verbose = False ) # type: ignore
17101708 return b .get_array ().sum ()
17111709
1710+ def extract_background (self ,inplace = False ):
1711+ assert self .seg , "extracting the background only makes sense for a segmentation mask"
1712+ arr_bg = self .get_seg_array ()
1713+ arr_bg = np_extract_label (arr_bg , label = 0 , to_label = 1 )
1714+ return self .set_array (arr_bg , inplace , False )
1715+
1716+
17121717 def extract_label (self ,label :int | Enum | Sequence [int ]| Sequence [Enum ], keep_label = False ,inplace = False ):
17131718 '''If this NII is a segmentation you can single out one label with [0,1].'''
1719+ assert self .seg , "extracting a label only makes sense for a segmentation mask"
17141720 seg_arr = self .get_seg_array ()
17151721
17161722 if isinstance (label , Sequence ):
@@ -1726,7 +1732,7 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum], keep_label=F
17261732 if isinstance (label ,str ):
17271733 label = int (label )
17281734
1729- assert label != 0 , 'Zero label does not make sens . This is the background'
1735+ assert label != 0 , 'Zero label does not make sense . This is the background'
17301736 seg_arr [seg_arr != label ] = 0
17311737 seg_arr [seg_arr == label ] = 1
17321738 if keep_label :
0 commit comments