Skip to content

Commit 5f20f41

Browse files
authored
New features robert (#75)
* add remove_to_label_for_cc_filter * add remove_to * add itk global_cord support * add referecne for loading POI (POI.load) * enable loading via Path * add set_above_3_point_plane * make an internal function to make vert-seg * do not corp when empty instead of failing * fix issue with empty images and cropout * bug fixes * ruff * remove deprecated function * x * move assert_affine ot has_grid to avoid duplicat code in NII and POI * add is_segmentation_in_border * add support for itk_cords in Global poi * add new POI definition * add saving, refactor point reg * prepare save and load global coords with POI * enable selection of cuda,mps and cpu for reg * refactor move save load poi in its separate file * add save / load for global POI * bug fix saving * change default Enum for POI * ruff * fix py 3.9; remove error in test --------- Co-authored-by: ga84mun <robert.graf@tum.de>
1 parent 0bf3126 commit 5f20f41

File tree

5 files changed

+16
-18
lines changed

5 files changed

+16
-18
lines changed

TPTBox/core/poi.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,19 +1169,13 @@ def calc_centroids(
11691169
assert first_stage == -1 or second_stage == -1, "first or second dimension must be fixed."
11701170
msk_nii = to_nii(msk, seg=True)
11711171
msk_data = msk_nii.get_seg_array()
1172-
axc: AX_CODES = nio.aff2axcodes(msk_nii.affine) # type: ignore
11731172
if extend_to is None:
11741173
ctd_list = POI_Descriptor()
11751174
else:
11761175
if not inplace:
11771176
extend_to = extend_to.copy()
11781177
ctd_list = extend_to.centroids
1179-
extend_to.assert_affine(
1180-
msk_nii,
1181-
shape_tolerance=0.5,
1182-
origin_tolerance=0.5,
1183-
)
1184-
# assert extend_to.orientation == axc, (extend_to.orientation, axc)
1178+
extend_to.assert_affine(msk_nii, shape_tolerance=0.5, origin_tolerance=0.5)
11851179
for i in msk_nii.unique():
11861180
msk_temp = np.zeros(msk_data.shape, dtype=bool)
11871181
msk_temp[msk_data == i] = True
@@ -1190,7 +1184,7 @@ def calc_centroids(
11901184
ctd_list[first_stage, int(i)] = tuple(round(x, decimals) for x in ctr_mass)
11911185
else:
11921186
ctd_list[int(i), second_stage] = tuple(round(x, decimals) for x in ctr_mass)
1193-
return POI(ctd_list, orientation=axc, **msk_nii._extract_affine(rm_key=["orientation"]), **args)
1187+
return POI(ctd_list, **msk_nii._extract_affine(), **args)
11941188

11951189

11961190
######## Utility #######

TPTBox/registration/deformable/_deepali/deform_reg_pair.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import json
24
from pathlib import Path
35
from timeit import default_timer as timer

TPTBox/registration/deformable/_deepali/registration_losses.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
r"""Registration loss for pairwise image registration."""
1+
from __future__ import annotations
22

33
import re
44
from collections import defaultdict
@@ -24,6 +24,8 @@
2424
from torch import Tensor
2525
from torch.nn import Module
2626

27+
r"""Registration loss for pairwise image registration."""
28+
2729
RE_WEIGHT = re.compile(r"^((?P<mul>[0-9]+(\.[0-9]+)?)\s*[\* ])?\s*(?P<chn>[a-zA-Z0-9_-]+)\s*(\+\s*(?P<add>[0-9]+(\.[0-9]+)?))?$")
2830
RE_TERM_VAR = re.compile(r"^[a-zA-Z0-9_-]+\((?P<var>[a-zA-Z0-9_]+)\)$")
2931

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ tqdm = "*"
2929
joblib = "*"
3030
scikit-learn = "*"
3131
antspyx = "0.4.2"
32+
hf-deepali = "*"
3233

3334
[tool.poetry.dev-dependencies]
3435
pytest = ">=8.1.1"

unit_tests/test_testsamples.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import unittest # noqa: E402
1818

19-
from TPTBox import NII, Location, calc_poi_from_subreg_vert # noqa: E402
19+
from TPTBox import NII, Location, Print_Logger, calc_poi_from_subreg_vert # noqa: E402
2020
from TPTBox.tests.test_utils import get_test_ct, get_test_mri, get_tests_dir # noqa: E402
2121

2222

@@ -84,15 +84,13 @@ def test_make_snapshot_both(self, keep_images=False):
8484

8585
def make_POIs(self, vert_nii: NII, subreg_nii: NII, vert_id: int, ignore_list: list[Location], locs: None | list[Location] = None, n=5):
8686
for i in range(n):
87-
if locs is None:
88-
locs = [l for l in Location if l not in ignore_list and random.random() < (i + 1) / n * 3]
89-
poi = calc_poi_from_subreg_vert(vert_nii, subreg_nii, subreg_id=locs, verbose=False, _print_phases=True).extract_region(vert_id)
90-
for l in locs:
91-
self.assertIn((vert_id, l.value), poi)
92-
poi.assert_affine(
93-
vert_nii,
94-
shape_tolerance=0.5,
87+
locs2 = [l for l in Location if l not in ignore_list and random.random() < (i + 1) / n * 3] if locs is None else locs
88+
poi = calc_poi_from_subreg_vert(vert_nii, subreg_nii, subreg_id=locs2, verbose=False, _print_phases=True).extract_region(
89+
vert_id
9590
)
91+
for l in locs2:
92+
self.assertIn((vert_id, l.value), poi)
93+
poi.assert_affine(vert_nii, shape_tolerance=0.5)
9694

9795
def test_POIs_CT(self):
9896
_, subreg_nii, vert_nii, label = get_test_ct()
@@ -107,6 +105,7 @@ def test_POIs_CT(self):
107105
Location.Dens_axis,
108106
Location.Unknown,
109107
Location.Endplate,
108+
Location.Vertebra_Disc, # CT example has no disc...
110109
Location.Spinal_Cord,
111110
Location.Spinal_Canal,
112111
Location.Spinal_Canal_ivd_lvl,

0 commit comments

Comments
 (0)