Skip to content

Commit 9f6c6f3

Browse files
authored
Merge pull request #68 from Hendrik-code/speed_and_nputils
Speed and nputils
2 parents 917f158 + c3e85e0 commit 9f6c6f3

17 files changed

+949
-109
lines changed

TPTBox/core/nii_wrapper.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from TPTBox.core.nii_poi_abstract import Has_Grid
1818
from TPTBox.core.nii_wrapper_math import NII_Math
1919
from TPTBox.core.np_utils import (
20+
_pad_to_parameters,
2021
np_calc_boundary_mask,
2122
np_calc_convex_hull,
2223
np_calc_overlapping_labels,
@@ -25,10 +26,13 @@
2526
np_connected_components,
2627
np_dilate_msk,
2728
np_erode_msk,
29+
np_extract_label,
2830
np_fill_holes,
31+
np_fill_holes_global_with_majority_voting,
2932
np_get_connected_components_center_of_mass,
3033
np_get_largest_k_connected_components,
3134
np_map_labels,
35+
np_map_labels_based_on_majority_label_mask_overlap,
3236
np_point_coordinates,
3337
np_smooth_gaussian_labelwise,
3438
np_unique,
@@ -670,7 +674,7 @@ def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_cro
670674
#origin_shift = tuple([int(ex_slice[i].start) for i in range(len(ex_slice))])
671675
return tuple(ex_slice)# type: ignore
672676

673-
def apply_center_crop(self, center_shape: tuple[int,int,int], verbose: bool = False):
677+
def apply_center_crop(self, center_shape: tuple[int,int,int], inplace=False, verbose: bool = False):
674678
shp_x, shp_y, shp_z = self.shape
675679
crop_x, crop_y, crop_z = center_shape
676680
arr = self.get_array()
@@ -698,7 +702,8 @@ def apply_center_crop(self, center_shape: tuple[int,int,int], verbose: bool = Fa
698702
log.print(f"Center cropped from {arr_padded.shape} to {arr_cropped.shape}", verbose=verbose)
699703
shp_x, shp_y, shp_z = arr_cropped.shape
700704
assert crop_x == shp_x and crop_y == shp_y and crop_z == shp_z
701-
return self.set_array(arr_cropped)
705+
return self.set_array(arr_cropped, inplace=inplace)
706+
#return self.apply_crop(crop_slices, inplace=inplace)
702707

703708
def apply_crop_slice(self,*args,**qargs):
704709
import warnings
@@ -733,22 +738,7 @@ def apply_crop_(self,ex_slice:tuple[slice,slice,slice]|Sequence[slice]):
733738
def pad_to(self,target_shape:list[int]|tuple[int,int,int] | Self, mode:MODES="constant",crop=False,inplace = False):
734739
if isinstance(target_shape, NII):
735740
target_shape = target_shape.shape
736-
padding = []
737-
crop = []
738-
requires_crop = False
739-
for in_size, out_size in zip(self.shape[-3:], target_shape[-3:],strict=True):
740-
to_pad_size = max(0, out_size - in_size) / 2.0
741-
to_crop_size = -min(0, out_size - in_size) / 2.0
742-
padding.extend([(ceil(to_pad_size), floor(to_pad_size))])
743-
if to_crop_size == 0:
744-
crop.append(slice(None))
745-
else:
746-
end = -floor(to_crop_size)
747-
if end == 0:
748-
end = None
749-
crop.append(slice(ceil(to_crop_size), end))
750-
requires_crop = True
751-
741+
padding, crop, requires_crop = _pad_to_parameters(self.shape, target_shape)
752742
s = self
753743
if crop and requires_crop:
754744
s = s.apply_crop(tuple(crop),inplace=inplace)
@@ -1115,26 +1105,13 @@ def erode_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectivi
11151105
11161106
"""
11171107
log.print("erode mask",end='\r',verbose=verbose)
1118-
if use_crop:
1119-
try:
1120-
crop = (self if labels is None else self.extract_label(labels)).compute_crop(dist=1)
1121-
except ValueError:
1122-
return self if inplace else self.copy()
1123-
1124-
msk_i_data_org = self.get_seg_array()
1125-
msk_i_data = msk_i_data_org[crop]
1126-
else:
1127-
msk_i_data = self.get_seg_array()
1108+
msk_i_data = self.get_seg_array()
11281109
labels = self.unique() if labels is None else labels
11291110
if isinstance(ignore_direction,str):
11301111
ignore_direction = self.get_axis(ignore_direction)
1131-
out = np_erode_msk(msk_i_data, label_ref=labels, mm=n_pixel, connectivity=connectivity,border_value=border_value,ignore_axis=ignore_direction)
1112+
out = np_erode_msk(msk_i_data, label_ref=labels, n_pixel=n_pixel, connectivity=connectivity,border_value=border_value,ignore_axis=ignore_direction, use_crop=use_crop)
11321113
out = out.astype(self.dtype)
11331114
log.print("Mask eroded by", n_pixel, "voxels",verbose=verbose)
1134-
1135-
if use_crop:
1136-
msk_i_data_org[crop] = out
1137-
out = msk_i_data_org
11381115
return self.set_array(out,inplace=inplace)
11391116

11401117
def erode_msk_(self, n_pixel:int = 5, labels: LABEL_REFERENCE = None, connectivity: int=3, verbose:logging=True,border_value=0,use_crop=True,ignore_direction:DIRECTIONS|int|None=None):
@@ -1162,24 +1139,13 @@ def dilate_msk(self, n_pixel: int = 5, labels: LABEL_REFERENCE = None, connectiv
11621139
log.print("dilate mask",end='\r',verbose=verbose)
11631140
if labels is None:
11641141
labels = self.unique()
1165-
if use_crop:
1166-
try:
1167-
crop = (self if labels is None else self.extract_label(labels)).compute_crop(dist=1+n_pixel)
1168-
except ValueError:
1169-
return self if inplace else self.copy()
1170-
msk_i_data_org = self.get_seg_array()
1171-
msk_i_data = msk_i_data_org[crop]
1172-
else:
1173-
msk_i_data = self.get_seg_array()
1142+
msk_i_data = self.get_seg_array()
11741143
mask_ = mask.get_seg_array() if mask is not None else None
11751144
if isinstance(ignore_direction,str):
11761145
ignore_direction = self.get_axis(ignore_direction)
1177-
out = np_dilate_msk(arr=msk_i_data, label_ref=labels, mm=n_pixel, mask=mask_, connectivity=connectivity,ignore_axis=ignore_direction)
1146+
out = np_dilate_msk(arr=msk_i_data, label_ref=labels, n_pixel=n_pixel, mask=mask_, connectivity=connectivity,ignore_axis=ignore_direction, use_crop=use_crop)
11781147
out = out.astype(self.dtype)
11791148
log.print("Mask dilated by", n_pixel, "voxels",verbose=verbose)
1180-
if use_crop:
1181-
msk_i_data_org[crop] = out
1182-
out = msk_i_data_org
11831149
return self.set_array(out,inplace=inplace)
11841150

11851151
def dilate_msk_(self, n_pixel:int = 5, labels: LABEL_REFERENCE = None, connectivity: int=3, mask: Self | None = None, verbose:logging=True,use_crop=True,ignore_direction:DIRECTIONS|int|None=None):
@@ -1218,7 +1184,7 @@ def fill_holes(self, labels: LABEL_REFERENCE = None, slice_wise_dim: int|str | N
12181184
if isinstance(slice_wise_dim,str):
12191185
slice_wise_dim = self.get_axis(slice_wise_dim)
12201186
#seg_arr = self.get_seg_array()
1221-
filled = np_fill_holes(seg_arr, label_ref=labels, slice_wise_dim=slice_wise_dim)
1187+
filled = np_fill_holes(seg_arr, label_ref=labels, slice_wise_dim=slice_wise_dim, use_crop=use_crop)
12221188
if use_crop:
12231189
msk_i_data_org[crop] = filled
12241190
filled = msk_i_data_org
@@ -1378,6 +1344,38 @@ def compute_surface_points(self, connectivity: int, dilated_surface: bool = Fals
13781344
return np_point_coordinates(surface.get_seg_array()) # type: ignore
13791345

13801346

1347+
def fill_holes_global_with_majority_voting(self, connectivity: int = 3, inplace: bool = False, verbose: bool = False):
1348+
"""Fills 3D holes globally, and resolves inter-label conflicts with majority voting by neighbors
1349+
1350+
Args:
1351+
connectivity (int, optional): Connectivity of fill holes. Defaults to 3.
1352+
inplace (bool, optional): Defaults to False.
1353+
verbose (bool, optional): Defaults to False.
1354+
1355+
Returns:
1356+
NII:
1357+
"""
1358+
assert self.seg, "only works with segmentation masks"
1359+
arr = np_fill_holes_global_with_majority_voting(self.get_seg_array(), connectivity=connectivity, verbose=verbose, inplace=inplace)
1360+
return self.set_array(arr,inplace=inplace)
1361+
1362+
1363+
def map_labels_based_on_majority_label_mask_overlap(self, label_mask: Self, labels: int | list[int] | None = None, dilate_pixel: int = 1, inplace: bool = False):
1364+
"""Relabels all individual labels from input array to the majority labels of a given label_mask
1365+
1366+
Args:
1367+
label_mask (np.ndarray): the mask from which to pull the target labels.
1368+
labels (int | list[int] | None, optional): Which labels in the input to process. Defaults to None.
1369+
dilate_pixel (int, optional): If true, will dilate the input to calculate the overlap. Defaults to 1.
1370+
inplace (bool, optional): Defaults to False.
1371+
1372+
Returns:
1373+
NII: Relabeled nifti
1374+
"""
1375+
assert self.seg and label_mask.seg, "This only works on segmentations"
1376+
return self.set_array(np_map_labels_based_on_majority_label_mask_overlap(self.get_seg_array(), label_mask.get_seg_array(), labels=labels, dilate_pixel=dilate_pixel, inplace=inplace), inplace=inplace,)
1377+
1378+
13811379
def get_segmentation_difference_to(self, mask_gt: Self, ignore_background_tp: bool = False) -> Self:
13821380
"""Calculates an NII that represents the segmentation difference between self and given groundtruth mask
13831381
@@ -1709,8 +1707,16 @@ def get_intersecting_volume(self, b: Self) -> bool:
17091707
b = b.resample_from_to(self,c_val=0,verbose=False) # type: ignore
17101708
return b.get_array().sum()
17111709

1710+
def extract_background(self,inplace=False):
1711+
assert self.seg, "extracting the background only makes sense for a segmentation mask"
1712+
arr_bg = self.get_seg_array()
1713+
arr_bg = np_extract_label(arr_bg, label=0, to_label=1)
1714+
return self.set_array(arr_bg, inplace, False)
1715+
1716+
17121717
def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum], keep_label=False,inplace=False):
17131718
'''If this NII is a segmentation you can single out one label with [0,1].'''
1719+
assert self.seg, "extracting a label only makes sense for a segmentation mask"
17141720
seg_arr = self.get_seg_array()
17151721

17161722
if isinstance(label, Sequence):
@@ -1726,7 +1732,7 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum], keep_label=F
17261732
if isinstance(label,str):
17271733
label = int(label)
17281734

1729-
assert label != 0, 'Zero label does not make sens. This is the background'
1735+
assert label != 0, 'Zero label does not make sense. This is the background'
17301736
seg_arr[seg_arr != label] = 0
17311737
seg_arr[seg_arr == label] = 1
17321738
if keep_label:

0 commit comments

Comments
 (0)