Skip to content

Commit 0bf3126

Browse files
authored
Updates (#74)
* add remove_to_label_for_cc_filter * add remove_to * add itk global_cord support * add referecne for loading POI (POI.load) * enable loading via Path * add set_above_3_point_plane * make an internal function to make vert-seg * do not corp when empty instead of failing * fix issue with empty images and cropout * bug fixes * ruff * remove deprecated function * x * move assert_affine ot has_grid to avoid duplicat code in NII and POI * add is_segmentation_in_border * add support for itk_cords in Global poi * add new POI definition * add saving, refactor point reg * prepare save and load global coords with POI * enable selection of cuda,mps and cpu for reg * refactor move save load poi in its separate file * add save / load for global POI * bug fix saving * change default Enum for POI * ruff --------- Co-authored-by: ga84mun <robert.graf@tum.de>
1 parent 88c4c1a commit 0bf3126

File tree

20 files changed

+1059
-673
lines changed

20 files changed

+1059
-673
lines changed

TPTBox/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
to_nii_interpolateable,
2222
to_nii_optional,
2323
to_nii_seg,
24+
Has_Grid,
2425
)
2526
from TPTBox.core.poi import AX_CODES, POI, POI_Reference, calc_centroids, calc_poi_from_subreg_vert
2627
from TPTBox.core.poi import calc_poi_from_two_segs
2728
from TPTBox.core.poi import calc_poi_from_two_segs as calc_poi_labeled_buffered
28-
from TPTBox.core.poi import load_poi
29-
from TPTBox.core.poi import load_poi as load_centroids
3029
from TPTBox.core.poi_fun.poi_global import POI_Global
3130
from TPTBox.core.vert_constants import ZOOMS, Location, Vertebra_Instance, v_idx2name, v_idx_order, v_name2idx
3231

TPTBox/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# NII
1111
from .nii_wrapper import NII, Image_Reference, Interpolateable_Image_Reference, to_nii, to_nii_interpolateable, to_nii_optional, to_nii_seg
12-
from .poi import AX_CODES, POI, POI_Reference, calc_centroids, calc_poi_from_subreg_vert, calc_poi_from_two_segs, load_poi
12+
from .poi import AX_CODES, POI, POI_Reference, calc_centroids, calc_poi_from_subreg_vert, calc_poi_from_two_segs
1313
from .poi_fun.poi_global import POI_Global
1414
from .vert_constants import ZOOMS, Location, v_idx2name, v_idx_order, v_name2idx
1515

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Literal
2+
3+
import torch
4+
5+
from TPTBox.core.vert_constants import never_called
6+
7+
DEVICES = Literal["cpu", "cuda", "mps"]
8+
9+
10+
def get_device(ddevice: DEVICES, gpu_id: int):
11+
if ddevice == "cpu":
12+
# import multiprocessing
13+
14+
# try:
15+
# torch.set_num_threads(multiprocessing.cpu_count())
16+
# except Exception:
17+
# pass
18+
device = torch.device("cpu")
19+
elif ddevice == "cuda":
20+
device = torch.device(type="cuda", index=gpu_id)
21+
elif ddevice == "mps":
22+
device = torch.device("mps")
23+
else:
24+
never_called(ddevice)
25+
return device

TPTBox/core/nii_poi_abstract.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,11 @@ def spacing(self, value: ZOOMS):
7171
self.zoom = value
7272

7373
def __str__(self) -> str:
74-
#origin={tuple(np.around(self.origin, decimals=2))}
75-
return f"shape={self.shape},spacing={tuple(np.around(self.zoom, decimals=2))}, origin={self.origin}, ori={self.orientation}" # type: ignore
74+
try:
75+
origin = {tuple(np.around(self.origin, decimals=2))}
76+
except Exception:
77+
origin = self.origin
78+
return f"shape={self.shape},spacing={tuple(np.around(self.zoom, decimals=2))}, origin={origin}, ori={self.orientation}" # type: ignore
7679

7780
@property
7881
def affine(self):
@@ -95,15 +98,22 @@ def affine(self, affine: np.ndarray):
9598
self.rotation = rotation
9699
self.origin = origin.tolist()
97100

98-
def _extract_affine(self: Has_Grid, rm_key=()):
99-
out = {"zoom": self.spacing, "origin": self.origin, "shape": self.shape, "rotation": self.rotation, "orientation": self.orientation}
101+
def _extract_affine(self: Has_Grid, rm_key=(), **args):
102+
out = {
103+
"zoom": self.spacing,
104+
"origin": self.origin,
105+
"shape": self.shape,
106+
"rotation": self.rotation,
107+
"orientation": self.orientation,
108+
**args,
109+
}
100110
for k in rm_key:
101111
out.pop(k)
102112
return out
103113

104114
def assert_affine(
105115
self,
106-
other: Self | NII | POI | None = None,
116+
other: Self | Has_Grid | None = None,
107117
ignore_missing_values: bool = False,
108118
affine: AFFINE | None = None,
109119
zoom: ZOOMS | None = None,
@@ -121,7 +131,7 @@ def assert_affine(
121131
"""Checks if the different metadata is equal to some comparison entries
122132
123133
Args:
124-
other (Self | POI | None, optional): If set, will assert each entry of that object instead. Defaults to None.
134+
other (Has_Grid | None, optional): If set, will assert each entry of that object instead. Defaults to None.
125135
affine (AFFINE | None, optional): Affine matrix to compare against. If none, will not assert affine. Defaults to None.
126136
zms (Zooms | None, optional): Zoom to compare against. If none, will not assert zoom. Defaults to None.
127137
orientation (Ax_Codes | None, optional): Orientation to compare against. If none, will not assert orientation. Defaults to None.
@@ -251,21 +261,26 @@ def make_empty_POI(self, points: dict | None = None):
251261
from TPTBox import POI
252262

253263
p = {} if points is None else points
254-
return POI(p, orientation=self.orientation, zoom=self.zoom, shape=self.shape, rotation=self.rotation, origin=self.origin)
264+
args = {}
265+
if isinstance(self, POI):
266+
args["level_one_info"] = self.level_one_info
267+
args["level_two_info"] = self.level_two_info
268+
269+
return POI(p, orientation=self.orientation, zoom=self.zoom, shape=self.shape, rotation=self.rotation, origin=self.origin, **args)
255270

256271
def make_empty_nii(self, seg=False, _arr=None):
257272
from TPTBox import NII
258273

259274
if _arr is None:
260275
_arr = np.zeros(self.shape_int)
261276
else:
262-
assert (
263-
_arr.shape == self.shape_int
264-
), f"Expected the correct shape for generating a image from Grid; Got {_arr.shape}, expected {self.shape_int}"
277+
assert _arr.shape == self.shape_int, (
278+
f"Expected the correct shape for generating a image from Grid; Got {_arr.shape}, expected {self.shape_int}"
279+
)
265280
nii = nib.Nifti1Image(_arr, affine=self.affine)
266281
return NII(nii, seg=seg)
267282

268-
def make_nii(self, arr: np.ndarray, seg=False):
283+
def make_nii(self, arr: np.ndarray | None = None, seg=False):
269284
"""Make a nii with the same grid as object. Shape must fit the Grid.
270285
271286
Args:
@@ -275,6 +290,8 @@ def make_nii(self, arr: np.ndarray, seg=False):
275290
Returns:
276291
NII
277292
"""
293+
if arr is None:
294+
arr = np.zeros(self.shape_int)
278295
return self.make_empty_nii(_arr=arr, seg=seg)
279296

280297
def global_to_local(self, x: COORDINATE):

TPTBox/core/nii_wrapper.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ def load_nrrd(cls, path: str | Path, seg: bool):
234234
raise KeyError(f"Missing expected header field: {e}") from None
235235
ref_orientation = header.get("ref_orientation")
236236
for i in ["ref_orientation","dimension","space directions","space origin""space","type","endian"]:
237-
if i in header :
238-
del header[i]
237+
header.pop(i, None)
239238
for key in list(header.keys()):
240239
if "_Extent" in key:
241240
del header[key]
@@ -696,9 +695,9 @@ def apply_center_crop(self, center_shape: tuple[int,int,int], inplace=False, ver
696695
else:
697696
arr_padded = arr
698697

699-
crop_rel_x = int(round((shp_x - crop_x) / 2.0))
700-
crop_rel_y = int(round((shp_y - crop_y) / 2.0))
701-
crop_rel_z = int(round((shp_z - crop_z) / 2.0))
698+
crop_rel_x = round((shp_x - crop_x) / 2.0)
699+
crop_rel_y = round((shp_y - crop_y) / 2.0)
700+
crop_rel_z = round((shp_z - crop_z) / 2.0)
702701

703702
crop_slices = (slice(crop_rel_x, crop_rel_x + crop_x),slice(crop_rel_y, crop_rel_y + crop_y),slice(crop_rel_z, crop_rel_z + crop_z))
704703
arr_cropped = arr_padded[crop_slices]
@@ -1427,6 +1426,27 @@ def get_overlapping_labels_to(
14271426
assert self.seg and mask_other.seg
14281427
return np_calc_overlapping_labels(self.get_seg_array(), mask_other.get_seg_array())
14291428

1429+
def is_segmentation_in_border(self,minimum=0, voxel_tolerance: int = 2,use_mm=False):
1430+
"""
1431+
Checks if the segmentation is touching the border of the image volume.
1432+
1433+
Parameters:
1434+
- minimum (int, optional): Minimum intensity threshold for segmentation. Defaults to 0.
1435+
- voxel_tolerance (int, optional): Number of voxels allowed as tolerance from the border. Defaults to 2.
1436+
- use_mm (bool, optional): Whether to use millimeter units instead of voxels. Defaults to False.
1437+
1438+
Returns:
1439+
- bool: True if the segmentation is within the defined voxel tolerance of the border, False otherwise.
1440+
"""
1441+
slices = self.compute_crop(minimum,dist=0,use_mm=use_mm)
1442+
shp = self.shape
1443+
seg_at_border = False
1444+
for d in range(3):
1445+
if slices[d].start <= voxel_tolerance or slices[d].stop - 1 >= shp[d] - voxel_tolerance:
1446+
seg_at_border = True
1447+
break
1448+
return seg_at_border
1449+
14301450
def truncate_labels_beyond_reference_(
14311451
self, idx: int | list[int] = 1, not_beyond: int | list[int] = 1, fill: int = 0, axis: DIRECTIONS = "S", inclusion: bool = False, inplace: bool = True
14321452
):
@@ -1618,8 +1638,7 @@ def save_nrrd(self:Self, file: str | Path|bids_files.BIDS_FILE,make_parents=True
16181638
'encoding': 'gzip',
16191639
**_header,**self.info
16201640
}
1621-
if "Segmentation_ConversionParameters" in header:
1622-
del header["Segmentation_ConversionParameters"]
1641+
header.pop("Segmentation_ConversionParameters", None)
16231642
# Save NRRD file
16241643

16251644
log.print(f"Saveing {file}",verbose=verbose,ltype=Log_Type.SAVE,end='\r')

0 commit comments

Comments
 (0)