Skip to content

Commit 27b76e2

Browse files
authored
Merge pull request #91 from Hendrik-code/snapshot_readability
small bugfixes
2 parents 80527ce + 3dcdf1b commit 27b76e2

File tree

12 files changed

+219
-91
lines changed

12 files changed

+219
-91
lines changed

TPTBox/core/bids_files.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,28 +186,40 @@ def save_buffer(f: Path, buffer_name):
186186

187187
age = today - file_mod_time
188188
if age.days >= int(max_age_days):
189-
print(
190-
"[ ] Delete Buffer - to old:",
191-
(folder / buffer_name),
192-
f"{' ':20}",
193-
) if verbose else None
189+
(
190+
print(
191+
"[ ] Delete Buffer - to old:",
192+
(folder / buffer_name),
193+
f"{' ':20}",
194+
)
195+
if verbose
196+
else None
197+
)
194198
(folder / buffer_name).unlink()
195199
if (folder / buffer_name).exists() and parent not in recompute_parents:
196200
with open((folder / buffer_name), "rb") as b:
197201
l = pickle.load(b)
202+
(
203+
print(
204+
f"[{len(l):8}] Read Buffer:",
205+
(folder / buffer_name),
206+
f"{' ':20}",
207+
)
208+
if verbose
209+
else None
210+
)
211+
files[dataset] += l
212+
else:
213+
(
198214
print(
199-
f"[{len(l):8}] Read Buffer:",
215+
f"[{_cont:8}] Create new Buffer:",
200216
(folder / buffer_name),
201217
f"{' ':20}",
202-
) if verbose else None
203-
files[dataset] += l
204-
else:
205-
print(
206-
f"[{_cont:8}] Create new Buffer:",
207-
(folder / buffer_name),
208-
f"{' ':20}",
209-
end="\r",
210-
) if verbose else None
218+
end="\r",
219+
)
220+
if verbose
221+
else None
222+
)
211223
files[dataset] += save_buffer((folder), buffer_name)
212224
if filter_file is not None:
213225
files: dict[Path | str, list[Path]] = {d: [g for g in f if filter_file(g)] for d, f in files.items()}
@@ -353,10 +365,14 @@ def add_file_2_subject(self, bids: BIDS_FILE | Path, ds=None) -> None:
353365
if subject not in self.subjects:
354366
self.subjects[subject] = Subject_Container(subject, self.sequence_splitting_keys)
355367
self.count_file += 1
356-
print(
357-
f"Found: {subject}, total file keys {(self.count_file)}, total subjects = {len(self.subjects)} ",
358-
end="\r",
359-
) if self.verbose else None
368+
(
369+
print(
370+
f"Found: {subject}, total file keys {(self.count_file)}, total subjects = {len(self.subjects)} ",
371+
end="\r",
372+
)
373+
if self.verbose
374+
else None
375+
)
360376
self.subjects[subject].add(bids)
361377

362378
def enumerate_subjects(self, sort=False, shuffle=False) -> list[tuple[str, Subject_Container]]:
@@ -729,6 +745,9 @@ def get_changed_path( # noqa: C901
729745
info = {}
730746
if non_strict_mode and not self.BIDS_key.startswith("sub"):
731747
info["sub"] = self.BIDS_key.replace("_", "-").replace(".", "-")
748+
else:
749+
# replace _ with - in all info
750+
self.info = {k: v.replace("_", "-") for k, v in self.info.items()}
732751
if isinstance(file_type, str) and file_type.startswith("."):
733752
file_type = file_type[1:]
734753
path = self.insert_info_into_path(path)

TPTBox/core/nii_wrapper.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from nibabel import Nifti1Header, Nifti1Image # type: ignore
1717
from typing_extensions import Self
1818

19+
from TPTBox.core import bids_files
1920
from TPTBox.core.compat import zip_strict
2021
from TPTBox.core.internal.nii_help import _resample_from_to, secure_save
2122
from TPTBox.core.nii_poi_abstract import Has_Grid
@@ -47,10 +48,7 @@
4748
np_unique_withoutzero,
4849
np_volume,
4950
)
50-
from TPTBox.logger.log_file import Log_Type
51-
52-
from . import bids_files
53-
from .vert_constants import (
51+
from TPTBox.core.vert_constants import (
5452
AFFINE,
5553
AX_CODES,
5654
COORDINATE,
@@ -65,6 +63,7 @@
6563
logging,
6664
v_name2idx,
6765
)
66+
from TPTBox.logger.log_file import Log_Type
6867

6968
if TYPE_CHECKING:
7069
from torch import device
@@ -1065,7 +1064,7 @@ def normalize_to_range_(self, min_value: int = 0, max_value: int = 1500, verbose
10651064
mi, ma = self.min(), self.max()
10661065
self += -mi + min_value # min = 0
10671066
self_dtype = self.dtype
1068-
max_value2 = ma
1067+
max_value2 = self.max() # this is a new value if min got shifted
10691068
if max_value2 > max_value:
10701069
self *= max_value / max_value2
10711070
self.set_dtype_(self_dtype)
@@ -1125,7 +1124,9 @@ def smooth_gaussian_labelwise(
11251124
boundary_mode: str = "nearest",
11261125
dilate_prior: int = 0,
11271126
dilate_connectivity: int = 1,
1127+
dilate_channelwise: bool = False,
11281128
smooth_background: bool = True,
1129+
background_threshold: float | None = None,
11291130
inplace: bool = False,
11301131
):
11311132
"""Smoothes the segmentation mask by applying a gaussian filter label-wise and then using argmax to derive the smoothed segmentation labels again.
@@ -1145,8 +1146,20 @@ def smooth_gaussian_labelwise(
11451146
NII: The smoothed NII object.
11461147
"""
11471148
assert self.seg, "You cannot use this on a non-segmentation NII"
1148-
smoothed = np_smooth_gaussian_labelwise(self.get_seg_array(), label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity,smooth_background=smooth_background,)
1149-
return self.set_array(smoothed,inplace,verbose=False)
1149+
smoothed = np_smooth_gaussian_labelwise(
1150+
self.get_seg_array(),
1151+
label_to_smooth=label_to_smooth,
1152+
sigma=sigma,
1153+
radius=radius,
1154+
truncate=truncate,
1155+
boundary_mode=boundary_mode,
1156+
dilate_prior=dilate_prior,
1157+
dilate_connectivity=dilate_connectivity,
1158+
smooth_background=smooth_background,
1159+
background_threshold=background_threshold,
1160+
dilate_channelwise=dilate_channelwise,
1161+
)
1162+
return self.set_array(smoothed, inplace, verbose=False)
11501163

11511164
def smooth_gaussian_labelwise_(
11521165
self,
@@ -1157,9 +1170,23 @@ def smooth_gaussian_labelwise_(
11571170
boundary_mode: str = "nearest",
11581171
dilate_prior: int = 1,
11591172
dilate_connectivity: int = 1,
1160-
smooth_background: bool = True
1173+
dilate_channelwise: bool = False,
1174+
smooth_background: bool = True,
1175+
background_threshold: float | None = None,
11611176
):
1162-
return self.smooth_gaussian_labelwise(label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity, smooth_background=smooth_background, inplace=True,)
1177+
return self.smooth_gaussian_labelwise(
1178+
label_to_smooth=label_to_smooth,
1179+
sigma=sigma,
1180+
radius=radius,
1181+
truncate=truncate,
1182+
boundary_mode=boundary_mode,
1183+
dilate_prior=dilate_prior,
1184+
dilate_connectivity=dilate_connectivity,
1185+
smooth_background=smooth_background,
1186+
inplace=True,
1187+
background_threshold=background_threshold,
1188+
dilate_channelwise=dilate_channelwise,
1189+
)
11631190

11641191
def to_ants(self):
11651192
try:
@@ -1402,7 +1429,7 @@ def filter_connected_components(self, labels: int |list[int]|None=None,min_volum
14021429
#print("filter",nii.unique())
14031430
#assert max_count_component is None or nii.max() <= max_count_component, nii.unique()
14041431
return self.set_array(arr, inplace=inplace)
1405-
def filter_connected_components_(self, labels: int |list[int]|None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False):
1432+
def filter_connected_components_(self, labels: int |list[int]|None=None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False):
14061433
return self.filter_connected_components(labels,min_volume=min_volume,max_volume=max_volume, max_count_component = max_count_component, connectivity = connectivity,keep_label=keep_label,inplace=True)
14071434

14081435
def get_segmentation_connected_components_center_of_mass(self, label: int, connectivity: int = 3, sort_by_axis: int | None = None) -> list[COORDINATE]:

TPTBox/core/np_utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,14 @@ def np_compute_surface(arr: UINTARRAY, connectivity: int = 3, dilated_surface: b
645645
"""
646646
assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}"
647647
if dilated_surface:
648-
return np_dilate_msk(arr.copy(), n_pixel=1, connectivity=connectivity) - arr
648+
dil = np_dilate_msk(arr.copy(), n_pixel=1, connectivity=connectivity)
649+
dil[arr != 0] = 0 # remove all non-zero entries
650+
return dil
649651
else:
650-
return arr - np_erode_msk(arr.copy(), n_pixel=1, connectivity=connectivity)
652+
ero = np_erode_msk(arr.copy(), n_pixel=1, connectivity=connectivity)
653+
arr = arr.copy()
654+
arr[ero != 0] = 0 # remove all non-zero entries
655+
return arr
651656

652657

653658
def np_point_coordinates(
@@ -969,7 +974,9 @@ def np_smooth_gaussian_labelwise(
969974
boundary_mode: str = "nearest",
970975
dilate_prior: int = 0,
971976
dilate_connectivity: int = 3,
977+
dilate_channelwise: bool = False,
972978
smooth_background: bool = True,
979+
background_threshold: float | None = None,
973980
) -> UINTARRAY:
974981
"""Smoothes selected labels in a segmentation mask using Gaussian filtering,
975982
while keeping other labels unaffected.
@@ -1010,7 +1017,7 @@ def np_smooth_gaussian_labelwise(
10101017
for l in label_to_smooth:
10111018
assert l in sem_labels, f"You want to smooth label {l} but it is not present in the given segmentation mask"
10121019

1013-
if dilate_prior > 0:
1020+
if dilate_prior > 0 and not dilate_channelwise:
10141021
arr = np_dilate_msk(
10151022
arr,
10161023
n_pixel=dilate_prior,
@@ -1023,6 +1030,13 @@ def np_smooth_gaussian_labelwise(
10231030
sem_labels_plus_background.append(0)
10241031
for l in sem_labels_plus_background[:-1]:
10251032
arr_l = (arr == l).astype(float)
1033+
if dilate_prior > 0 and dilate_channelwise:
1034+
arr_l = np_dilate_msk(
1035+
arr_l,
1036+
n_pixel=dilate_prior,
1037+
label_ref=1,
1038+
connectivity=dilate_connectivity,
1039+
)
10261040
if l in label_to_smooth:
10271041
arr_l = gaussian_filter(
10281042
arr_l,
@@ -1053,6 +1067,9 @@ def np_smooth_gaussian_labelwise(
10531067
seg_arr_smoothed = np.argmax(arr_stack, axis=0)
10541068
seg_arr_s = seg_arr_smoothed.copy()
10551069

1070+
if background_threshold is not None:
1071+
seg_arr_smoothed[seg_arr_smoothed < background_threshold] = len(sem_labels_plus_background) - 1 # background label
1072+
10561073
for idx, l in enumerate(sem_labels_plus_background):
10571074
seg_arr_s[seg_arr_smoothed == idx] = l
10581075

TPTBox/core/vert_constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def __init__(
368368
self._rib = None
369369
self._ivd = None
370370
self._endplate = None
371+
self.has_rib = has_rib
371372
if has_rib:
372373
self._rib = (
373374
vertebra_label + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET if vertebra_label != 28 else 21 + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET
@@ -487,7 +488,7 @@ def get_previous_poi(self, poi: POI | NII | list[int]):
487488
C3 = 3
488489
C4 = 4
489490
C5 = 5
490-
C6 = 6
491+
C6 = 6, True, True
491492
C7 = 7, True, True
492493
T1 = 8, True, True
493494
T2 = 9, True, True

TPTBox/mesh3D/mesh.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@ def __init__(self, mesh: pv.PolyData) -> None:
3131
def save(self, filepath: str | Path, mode: MeshOutputType = MeshOutputType.PLY, verbose: logging = True):
3232
filepath = str(filepath)
3333
if not filepath.endswith(mode.value):
34-
filepath += mode.value
34+
filepath += "." + mode.value
3535

3636
filepath = Path(filepath)
37-
if not filepath.exists():
38-
raise FileNotFoundError(filepath)
37+
if not filepath.parent.exists():
38+
raise FileNotFoundError(filepath.parent)
3939

4040
if mode == MeshOutputType.PLY:
41-
self.mesh.export_obj(filepath)
41+
try:
42+
self.mesh.export_obj(filepath)
43+
except AttributeError:
44+
self.mesh.save(filepath)
4245
else:
4346
raise NotImplementedError(f"save with mode {mode}")
4447
log.print(f"Saved mesh: {filepath}", Log_Type.SAVE, verbose=verbose)
@@ -61,6 +64,18 @@ def show(self):
6164
pl.add_mesh(self.mesh)
6265
pl.show()
6366

67+
def save_to_html(self, file_output: str | Path):
68+
pv.start_xvfb()
69+
pl = pv.Plotter()
70+
pl.set_background("black", top=None)
71+
pl.add_axes()
72+
pv.global_theme.axes.show = True
73+
pv.global_theme.edge_color = "white"
74+
pv.global_theme.interactive = True
75+
76+
pl.add_mesh(self.mesh)
77+
pl.export_html(file_output)
78+
6479

6580
class SegmentationMesh(Mesh3D):
6681
def __init__(self, int_arr: np.ndarray | Image_Reference) -> None:

TPTBox/mesh3D/snapshot3D.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def make_snapshot3D(
3434
ids_list: list[Sequence[int]] | None = None,
3535
smoothing=20,
3636
resolution: float | None = None,
37-
width_factor=1.0,
37+
width_factor: float = 1.0,
38+
scale_factor: int = 1,
3839
verbose=True,
3940
crop=True,
4041
) -> Image.Image:
@@ -80,7 +81,7 @@ def make_snapshot3D(
8081
nii = to_nii_seg(img)
8182
if crop:
8283
try:
83-
nii.apply_crop_(nii.compute_crop())
84+
nii.apply_crop_(nii.compute_crop(dist=2))
8485
except ValueError:
8586
pass
8687
if resolution is None:
@@ -98,7 +99,7 @@ def make_snapshot3D(
9899
ids_list = ids_list2
99100

100101
# TOP : ("A", "I", "R")
101-
nii = nii.reorient(("A", "S", "L")).rescale_((resolution, resolution, resolution))
102+
nii = nii.reorient(("A", "S", "L")).rescale_((resolution, resolution, resolution), mode="constant")
102103
width = int(max(nii.shape[0], nii.shape[2]) * width_factor)
103104
window_size = (width * len(ids_list), nii.shape[1])
104105
with Xvfb():
@@ -110,7 +111,7 @@ def make_snapshot3D(
110111
_plot_sub_seg(scene, nii.extract_label(ids, keep_label=True), x, 0, smoothing, view[i % len(view)])
111112
scene.projection(proj_type="parallel")
112113
scene.reset_camera_tight(margin_factor=1.02)
113-
window.record(scene, size=window_size, out_path=output_path, reset_camera=False)
114+
window.record(scene, size=window_size, out_path=output_path, reset_camera=False, magnification=scale_factor)
114115
scene.clear()
115116
if not is_tmp:
116117
logger.on_save("Save Snapshot3D:", output_path, verbose=verbose)
@@ -129,6 +130,7 @@ def make_snapshot3D_parallel(
129130
resolution: float = 2,
130131
cpus=10,
131132
width_factor=1.0,
133+
scale_factor: int = 1,
132134
override=True,
133135
):
134136
ress = []
@@ -146,6 +148,7 @@ def make_snapshot3D_parallel(
146148
"smoothing": smoothing,
147149
"resolution": resolution,
148150
"width_factor": width_factor,
151+
"scale_factor": scale_factor,
149152
},
150153
)
151154
ress.append(res)

TPTBox/registration/deepali/deepali_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,9 @@ def __init__(
173173
fixed_mask: Image_Reference | None = None,
174174
moving_mask: Image_Reference | None = None,
175175
# normalize
176-
normalize_strategy: Literal["auto", "CT", "MRI"]
177-
| None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting
176+
normalize_strategy: (
177+
Literal["auto", "CT", "MRI"] | None
178+
) = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting
178179
# Pyramid
179180
pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest)
180181
finest_level: int = 0,
@@ -188,6 +189,7 @@ def __init__(
188189
transform_init: PathStr | None = None, # reload initial flowfield from file
189190
optim_name="Adam", # Optimizer name defined in torch.optim. or override on_optimizer finer control
190191
lr: float | Sequence[float] = 0.01, # Learning rate
192+
lr_end_factor: float | None = None, # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr
191193
optim_args=None, # args of Optimizer with out lr
192194
smooth_grad=0.0,
193195
verbose=99,
@@ -245,6 +247,7 @@ def __init__(
245247
transform_init=transform_init,
246248
optim_name=optim_name,
247249
lr=lr,
250+
lr_end_factor=lr_end_factor,
248251
optim_args=optim_args,
249252
smooth_grad=smooth_grad,
250253
verbose=verbose,

0 commit comments

Comments
 (0)