From 6d8da5a0f8f9727558caec4dda489cd93097cd33 Mon Sep 17 00:00:00 2001 From: iback Date: Tue, 16 Sep 2025 14:58:37 +0000 Subject: [PATCH 1/6] small bugfixes --- TPTBox/core/nii_wrapper.py | 14 +++++++++----- TPTBox/core/np_utils.py | 23 ++++++++++++++++++++--- TPTBox/core/vert_constants.py | 3 ++- TPTBox/mesh3D/mesh.py | 23 +++++++++++++++++++---- TPTBox/mesh3D/snapshot3D.py | 11 +++++++---- 5 files changed, 57 insertions(+), 17 deletions(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 3481259..4c97653 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -959,7 +959,7 @@ def normalize_to_range_(self, min_value: int = 0, max_value: int = 1500, verbose mi, ma = self.min(), self.max() self += -mi + min_value # min = 0 self_dtype = self.dtype - max_value2 = ma + max_value2 = self.max() # this is a new value if min got shifted if max_value2 > max_value: self *= max_value / max_value2 self.set_dtype_(self_dtype) @@ -1019,7 +1019,9 @@ def smooth_gaussian_labelwise( boundary_mode: str = "nearest", dilate_prior: int = 1, dilate_connectivity: int = 1, + dilate_channelwise: bool = False, smooth_background: bool = True, + background_threshold: float | None = None, inplace: bool = False, ): """Smoothes the segmentation mask by applying a gaussian filter label-wise and then using argmax to derive the smoothed segmentation labels again. @@ -1039,7 +1041,7 @@ def smooth_gaussian_labelwise( NII: The smoothed NII object. """ assert self.seg, "You cannot use this on a non-segmentation NII" - smoothed = np_smooth_gaussian_labelwise(self.get_seg_array(), label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity,smooth_background=smooth_background,) + smoothed = np_smooth_gaussian_labelwise(self.get_seg_array(), label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity,smooth_background=smooth_background,background_threshold=background_threshold,dilate_channelwise=dilate_channelwise) return self.set_array(smoothed,inplace,verbose=False) def smooth_gaussian_labelwise_( @@ -1051,9 +1053,11 @@ def smooth_gaussian_labelwise_( boundary_mode: str = "nearest", dilate_prior: int = 1, dilate_connectivity: int = 1, - smooth_background: bool = True + dilate_channelwise: bool = False, + smooth_background: bool = True, + background_threshold: float | None = None, ): - return self.smooth_gaussian_labelwise(label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity, smooth_background=smooth_background, inplace=True,) + return self.smooth_gaussian_labelwise(label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity, smooth_background=smooth_background, inplace=True, background_threshold=background_threshold, dilate_channelwise=dilate_channelwise) def to_ants(self): try: @@ -1296,7 +1300,7 @@ def filter_connected_components(self, labels: int |list[int]|None=None,min_volum #print("filter",nii.unique()) #assert max_count_component is None or nii.max() <= max_count_component, nii.unique() return self.set_array(arr, inplace=inplace) - def filter_connected_components_(self, labels: int 
|list[int]|None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False): + def filter_connected_components_(self, labels: int |list[int]|None=None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False): return self.filter_connected_components(labels,min_volume=min_volume,max_volume=max_volume, max_count_component = max_count_component, connectivity = connectivity,keep_label=keep_label,inplace=True) def get_segmentation_connected_components_center_of_mass(self, label: int, connectivity: int = 3, sort_by_axis: int | None = None) -> list[COORDINATE]: diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 9816f98..0e4f37c 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -634,9 +634,14 @@ def np_compute_surface(arr: UINTARRAY, connectivity: int = 3, dilated_surface: b """ assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}" if dilated_surface: - return np_dilate_msk(arr.copy(), n_pixel=1, connectivity=connectivity) - arr + dil = np_dilate_msk(arr.copy(), n_pixel=1, connectivity=connectivity) + dil[arr != 0] = 0 # keep only the dilated shell: zero out the original mask voxels + return dil else: - return arr - np_erode_msk(arr.copy(), n_pixel=1, connectivity=connectivity) + ero = np_erode_msk(arr.copy(), n_pixel=1, connectivity=connectivity) + arr = arr.copy() + arr[ero != 0] = 0 # keep only the surface: zero out the eroded interior + return arr def np_point_coordinates( @@ -954,7 +959,9 @@ def np_smooth_gaussian_labelwise( boundary_mode: str = "nearest", dilate_prior: int = 0, dilate_connectivity: int = 3, + dilate_channelwise: bool = False, smooth_background: bool = True, + background_threshold: float | None = None, ) -> UINTARRAY: """Smoothes labels in a segmentation mask array @@ -983,7 +990,7 @@ def np_smooth_gaussian_labelwise( for l in label_to_smooth: assert l in sem_labels, f"You want to smooth label {l} but it is not present in the given segmentation mask" - if dilate_prior > 0: + if dilate_prior > 0 and not dilate_channelwise: arr = np_dilate_msk( arr, n_pixel=dilate_prior, @@ -996,6 +1003,13 @@ def np_smooth_gaussian_labelwise( sem_labels_plus_background.append(0) for l in sem_labels_plus_background[:-1]: arr_l = (arr == l).astype(float) + if dilate_prior > 0 and dilate_channelwise: + arr_l = np_dilate_msk( + arr_l, + n_pixel=dilate_prior, + label_ref=1, + connectivity=dilate_connectivity, + ) if l in label_to_smooth: arr_l = gaussian_filter( arr_l, @@ -1026,6 +1040,9 @@ def np_smooth_gaussian_labelwise( seg_arr_smoothed = np.argmax(arr_stack, axis=0) seg_arr_s = seg_arr_smoothed.copy() + if background_threshold is not None: + seg_arr_smoothed[seg_arr_smoothed < background_threshold] = len(sem_labels_plus_background) - 1 # background label + for idx, l in enumerate(sem_labels_plus_background): seg_arr_s[seg_arr_smoothed == idx] = l diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index f3be05d..2b8e1b0 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -293,6 +293,7 @@ def __init__( self._rib = None self._ivd = None self._endplate = None + self.has_rib = has_rib if has_rib: self._rib = ( vertebra_label + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET if vertebra_label != 28 else 21 + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET @@ -398,7 +399,7 @@ def get_previous_poi(self, poi: POI | NII | list[int]): C3 = 3 C4 = 4 C5 = 5 - C6 = 6 + C6 = 6, True, True C7 = 7, True, True T1 = 8, True, True T2 = 9, True, True diff --git
a/TPTBox/mesh3D/mesh.py b/TPTBox/mesh3D/mesh.py index 32a35cf..073b338 100644 --- a/TPTBox/mesh3D/mesh.py +++ b/TPTBox/mesh3D/mesh.py @@ -31,14 +31,17 @@ def __init__(self, mesh: pv.PolyData) -> None: def save(self, filepath: str | Path, mode: MeshOutputType = MeshOutputType.PLY, verbose: logging = True): filepath = str(filepath) if not filepath.endswith(mode.value): - filepath += mode.value + filepath += "." + mode.value filepath = Path(filepath) - if not filepath.exists(): - raise FileNotFoundError(filepath) + if not filepath.parent.exists(): + raise FileNotFoundError(filepath.parent) if mode == MeshOutputType.PLY: - self.mesh.export_obj(filepath) + try: + self.mesh.export_obj(filepath) + except AttributeError: + self.mesh.save(filepath) else: raise NotImplementedError(f"save with mode {mode}") log.print(f"Saved mesh: {filepath}", Log_Type.SAVE, verbose=verbose) @@ -61,6 +64,18 @@ def show(self): pl.add_mesh(self.mesh) pl.show() + def save_to_html(self, file_output: str | Path): + pv.start_xvfb() + pl = pv.Plotter() + pl.set_background("black", top=None) + pl.add_axes() + pv.global_theme.axes.show = True + pv.global_theme.edge_color = "white" + pv.global_theme.interactive = True + + pl.add_mesh(self.mesh) + pl.export_html(file_output) + class SegmentationMesh(Mesh3D): def __init__(self, int_arr: np.ndarray | Image_Reference) -> None: diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py index 18ad63d..c4e6c86 100644 --- a/TPTBox/mesh3D/snapshot3D.py +++ b/TPTBox/mesh3D/snapshot3D.py @@ -34,7 +34,8 @@ def make_snapshot3D( ids_list: list[Sequence[int]] | None = None, smoothing=20, resolution: float | None = None, - width_factor=1.0, + width_factor: float = 1.0, + scale_factor: int = 1, verbose=True, crop=True, ) -> Image.Image: @@ -80,7 +81,7 @@ def make_snapshot3D( nii = to_nii_seg(img) if crop: try: - nii.apply_crop_(nii.compute_crop()) + nii.apply_crop_(nii.compute_crop(dist=2)) except ValueError: pass if resolution is None: @@ -98,7 +99,7 @@ def make_snapshot3D( ids_list = ids_list2 # TOP : ("A", "I", "R") - nii = nii.reorient(("A", "S", "L")).rescale_((resolution, resolution, resolution)) + nii = nii.reorient(("A", "S", "L")).rescale_((resolution, resolution, resolution), mode="constant") width = int(max(nii.shape[0], nii.shape[2]) * width_factor) window_size = (width * len(ids_list), nii.shape[1]) with Xvfb(): @@ -110,7 +111,7 @@ def make_snapshot3D( _plot_sub_seg(scene, nii.extract_label(ids, keep_label=True), x, 0, smoothing, view[i % len(view)]) scene.projection(proj_type="parallel") scene.reset_camera_tight(margin_factor=1.02) - window.record(scene, size=window_size, out_path=output_path, reset_camera=False) + window.record(scene, size=window_size, out_path=output_path, reset_camera=False, magnification=scale_factor) scene.clear() if not is_tmp: logger.on_save("Save Snapshot3D:", output_path, verbose=verbose) @@ -129,6 +130,7 @@ def make_sub_snapshot_parallel( resolution=2, cpus=10, width_factor=1.0, + scale_factor: int = 1, ): ress = [] with Pool(cpus) as p: # type: ignore @@ -143,6 +145,7 @@ def make_sub_snapshot_parallel( "smoothing": smoothing, "resolution": resolution, "width_factor": width_factor, + "scale_factor": scale_factor, }, ) ress.append(res) From 628cfc86e9c3ea4d8e0ecdd06b2b42ab8aca0bda Mon Sep 17 00:00:00 2001 From: Hendrik Date: Wed, 24 Sep 2025 12:29:59 +0200 Subject: [PATCH 2/6] Update TPTBox/core/nii_wrapper.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- TPTBox/core/nii_wrapper.py | 16 ++++++++++++++-- 1 file 
changed, 14 insertions(+), 2 deletions(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 84f81b5..f0a29a9 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1147,8 +1147,20 @@ def smooth_gaussian_labelwise( NII: The smoothed NII object. """ assert self.seg, "You cannot use this on a non-segmentation NII" - smoothed = np_smooth_gaussian_labelwise(self.get_seg_array(), label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity,smooth_background=smooth_background,background_threshold=background_threshold,dilate_channelwise=dilate_channelwise) - return self.set_array(smoothed,inplace,verbose=False) + smoothed = np_smooth_gaussian_labelwise( + self.get_seg_array(), + label_to_smooth=label_to_smooth, + sigma=sigma, + radius=radius, + truncate=truncate, + boundary_mode=boundary_mode, + dilate_prior=dilate_prior, + dilate_connectivity=dilate_connectivity, + smooth_background=smooth_background, + background_threshold=background_threshold, + dilate_channelwise=dilate_channelwise, + ) + return self.set_array(smoothed, inplace, verbose=False) def smooth_gaussian_labelwise_( self, From d1edd8e864ad3e4daaae0d202185ecf872f4adbf Mon Sep 17 00:00:00 2001 From: Hendrik Date: Wed, 24 Sep 2025 12:30:11 +0200 Subject: [PATCH 3/6] Update TPTBox/core/nii_wrapper.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- TPTBox/core/nii_wrapper.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index f0a29a9..ff51128 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1175,7 +1175,19 @@ def smooth_gaussian_labelwise_( smooth_background: bool = True, background_threshold: float | None = None, ): - return self.smooth_gaussian_labelwise(label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity, smooth_background=smooth_background, inplace=True, background_threshold=background_threshold, dilate_channelwise=dilate_channelwise) + return self.smooth_gaussian_labelwise( + label_to_smooth=label_to_smooth, + sigma=sigma, + radius=radius, + truncate=truncate, + boundary_mode=boundary_mode, + dilate_prior=dilate_prior, + dilate_connectivity=dilate_connectivity, + smooth_background=smooth_background, + inplace=True, + background_threshold=background_threshold, + dilate_channelwise=dilate_channelwise, + ) def to_ants(self): try: From ba31ff61c6394a1143858d38a8e8017230c54986 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 17 Oct 2025 14:45:16 +0000 Subject: [PATCH 4/6] registration_stuff mostly, but other small things --- TPTBox/core/bids_files.py | 57 +++++++++----- TPTBox/core/nii_wrapper.py | 7 +- TPTBox/registration/deepali/deepali_model.py | 7 +- .../registration/deepali/deepali_trainer.py | 35 ++++++--- .../registration/deformable/deformable_reg.py | 7 +- .../deformable/multilabel_segmentation.py | 76 ++++++++++++------- .../TotalVibeSeg/inference_nnunet.py | 18 +++-- 7 files changed, 135 insertions(+), 72 deletions(-) diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index 64d98ca..2f07686 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -186,28 +186,40 @@ def save_buffer(f: Path, buffer_name): age = today - file_mod_time if age.days >= int(max_age_days): - print( - 
"[ ] Delete Buffer - to old:", - (folder / buffer_name), - f"{' ':20}", - ) if verbose else None + ( + print( + "[ ] Delete Buffer - to old:", + (folder / buffer_name), + f"{' ':20}", + ) + if verbose + else None + ) (folder / buffer_name).unlink() if (folder / buffer_name).exists() and parent not in recompute_parents: with open((folder / buffer_name), "rb") as b: l = pickle.load(b) + ( + print( + f"[{len(l):8}] Read Buffer:", + (folder / buffer_name), + f"{' ':20}", + ) + if verbose + else None + ) + files[dataset] += l + else: + ( print( - f"[{len(l):8}] Read Buffer:", + f"[{_cont:8}] Create new Buffer:", (folder / buffer_name), f"{' ':20}", - ) if verbose else None - files[dataset] += l - else: - print( - f"[{_cont:8}] Create new Buffer:", - (folder / buffer_name), - f"{' ':20}", - end="\r", - ) if verbose else None + end="\r", + ) + if verbose + else None + ) files[dataset] += save_buffer((folder), buffer_name) if filter_file is not None: files: dict[Path | str, list[Path]] = {d: [g for g in f if filter_file(g)] for d, f in files.items()} @@ -353,10 +365,14 @@ def add_file_2_subject(self, bids: BIDS_FILE | Path, ds=None) -> None: if subject not in self.subjects: self.subjects[subject] = Subject_Container(subject, self.sequence_splitting_keys) self.count_file += 1 - print( - f"Found: {subject}, total file keys {(self.count_file)}, total subjects = {len(self.subjects)} ", - end="\r", - ) if self.verbose else None + ( + print( + f"Found: {subject}, total file keys {(self.count_file)}, total subjects = {len(self.subjects)} ", + end="\r", + ) + if self.verbose + else None + ) self.subjects[subject].add(bids) def enumerate_subjects(self, sort=False, shuffle=False) -> list[tuple[str, Subject_Container]]: @@ -729,6 +745,9 @@ def get_changed_path( # noqa: C901 info = {} if non_strict_mode and not self.BIDS_key.startswith("sub"): info["sub"] = self.BIDS_key.replace("_", "-").replace(".", "-") + else: + # replace _ with - in all info + self.info = {k: v.replace("_", "-") for k, v in self.info.items()} if isinstance(file_type, str) and file_type.startswith("."): file_type = file_type[1:] path = self.insert_info_into_path(path) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 84f81b5..ef8e1d7 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -16,6 +16,7 @@ from nibabel import Nifti1Header, Nifti1Image # type: ignore from typing_extensions import Self +from TPTBox.core import bids_files from TPTBox.core.compat import zip_strict from TPTBox.core.internal.nii_help import _resample_from_to, secure_save from TPTBox.core.nii_poi_abstract import Has_Grid @@ -47,10 +48,7 @@ np_unique_withoutzero, np_volume, ) -from TPTBox.logger.log_file import Log_Type - -from . import bids_files -from .vert_constants import ( +from TPTBox.core.vert_constants import ( AFFINE, AX_CODES, COORDINATE, @@ -65,6 +63,7 @@ logging, v_name2idx, ) +from TPTBox.logger.log_file import Log_Type if TYPE_CHECKING: from torch import device diff --git a/TPTBox/registration/deepali/deepali_model.py b/TPTBox/registration/deepali/deepali_model.py index 581ac62..628d6ce 100644 --- a/TPTBox/registration/deepali/deepali_model.py +++ b/TPTBox/registration/deepali/deepali_model.py @@ -173,8 +173,9 @@ def __init__( fixed_mask: Image_Reference | None = None, moving_mask: Image_Reference | None = None, # normalize - normalize_strategy: Literal["auto", "CT", "MRI"] - | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. 
auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: ( + Literal["auto", "CT", "MRI"] | None + ) = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do nothing # Pyramid pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, @@ -188,6 +189,7 @@ def __init__( transform_init: PathStr | None = None, # reload initial flowfield from file optim_name="Adam", # Optimizer name defined in torch.optim. or override on_optimizer finer control lr: float | Sequence[float] = 0.01, # Learning rate + lr_end_factor: float | None = None, # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr optim_args=None, # args of Optimizer with out lr smooth_grad=0.0, verbose=99, @@ -245,6 +247,7 @@ def __init__( transform_init=transform_init, optim_name=optim_name, lr=lr, + lr_end_factor=lr_end_factor, optim_args=optim_args, smooth_grad=smooth_grad, verbose=verbose, diff --git a/TPTBox/registration/deepali/deepali_trainer.py b/TPTBox/registration/deepali/deepali_trainer.py index a2edd5a..b9a1ef5 100644 --- a/TPTBox/registration/deepali/deepali_trainer.py +++ b/TPTBox/registration/deepali/deepali_trainer.py @@ -35,7 +35,7 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler +from torch.optim.lr_scheduler import LinearLR, LRScheduler from torch.utils.hooks import RemovableHandle from ._hooks import normalize_grad_hook, print_eval_loss_hook_tqdm, print_step_loss_hook_tqdm, smooth_grad_hook @@ -82,8 +82,9 @@ def __init__( source_mask: Union[Image, PathStr] | None = None, target_mask: Union[Image, PathStr] | None = None, # normalize - normalize_strategy: Literal["auto", "CT", "MRI"] - | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: ( + Literal["auto", "CT", "MRI"] | None + ) = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do nothing # Pyramid pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, @@ -97,6 +98,7 @@ def __init__( transform_init: PathStr | None = None, # reload initial flowfield from file optim_name="Adam", # Optimizer name defined in torch.optim.
or override on_optimizer finer control lr: float | Sequence[float] = 0.01, # Learning rate + lr_end_factor: float | None = None, # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr optim_args=None, # args of Optimizer with out lr smooth_grad=0.0, verbose=0, @@ -190,6 +192,7 @@ def __init__( self.model_init = transform_init self.optim_name = optim_name self.lr = lr if not isinstance(lr, (Sequence)) else lr[::-1] + self.lr_end_factor = lr_end_factor self.optim_args = optim_args self.max_steps = max_steps if not isinstance(max_steps, (Sequence)) else max_steps[::-1] self.max_history = max_history @@ -353,7 +356,12 @@ def make_pyramid( def on_make_transform(self, transform_name, grid, groups=1, **model_args): return new_spatial_transform(transform_name, grid, groups=groups, **model_args) - def on_optimizer(self, grid_transform: SequentialTransform, level) -> tuple[Optimizer, LRScheduler | None]: + def on_optimizer( + self, + grid_transform: SequentialTransform, + level, + lr_end_factor: float | None, + ) -> tuple[Optimizer, LRScheduler | None]: name = self.optim_name cls = getattr(torch.optim, name, None) if cls is None: @@ -362,7 +370,18 @@ def on_optimizer(self, grid_transform: SequentialTransform, level) -> tuple[Opti raise TypeError(f"Requested type '{name}' is not a subclass of torch.optim.Optimizer") kwargs = self.optim_args kwargs["lr"] = self.lr[level] if isinstance(self.lr, (list, tuple)) else self.lr - return cls(grid_transform.parameters(), **kwargs), None + + optimizer = cls(grid_transform.parameters(), **kwargs) + lr_sq = None + if lr_end_factor is not None and lr_end_factor > 0 and lr_end_factor < 1.0: + lr_sq = LinearLR( # type: ignore + optimizer, + start_factor=1.0, + end_factor=lr_end_factor, + total_iters=self.max_steps[level] if isinstance(self.max_steps, (list, tuple)) else self.max_steps, + ) + + return optimizer, lr_sq def on_converged(self, level) -> bool: r"""Check convergence criteria.""" @@ -463,9 +482,7 @@ def on_loss( # noqa: C901 else: mask = None for name, term in self.loss_pairwise_image_terms2.items(): - losses[name] = term( # DICE - moved_data.unsqueeze(0), target_data_seg.unsqueeze(0), mask=mask - ) + losses[name] = term(moved_data.unsqueeze(0), target_data_seg.unsqueeze(0), mask=mask) # DICE result["source"] = moved_data result["target"] = target_data result["mask"] = mask @@ -587,7 +604,7 @@ def _run_level( elif not isinstance(grid_transform, CompositeTransform): raise TypeError("PairwiseImageRegistrationLoss() 'transform' must be of type CompositeTransform") grid_transform = grid_transform.to(self.device) - opt, lr_sq = self.on_optimizer(grid_transform, level) + opt, lr_sq = self.on_optimizer(grid_transform, level, self.lr_end_factor) self.optimizer = opt if isinstance(self.max_steps, int): max_steps = self.max_steps diff --git a/TPTBox/registration/deformable/deformable_reg.py b/TPTBox/registration/deformable/deformable_reg.py index 0a9c73f..5f4c31d 100644 --- a/TPTBox/registration/deformable/deformable_reg.py +++ b/TPTBox/registration/deformable/deformable_reg.py @@ -48,8 +48,9 @@ def __init__( fixed_mask: Image_Reference | None = None, moving_mask: Image_Reference | None = None, # normalize - normalize_strategy: Literal["auto", "CT", "MRI"] - | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. 
auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: ( + Literal["auto", "CT", "MRI"] | None + ) = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do nothing # Pyramid pyramid_levels: int | None = 3, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, @@ -63,6 +64,7 @@ def __init__( transform_init: PathStr | None = None, # reload initial flowfield from file optim_name="Adam", # Optimizer name defined in torch.optim. or override on_optimizer finer controle lr: float | Sequence[float] = 0.001, # Learning rate + lr_end_factor: float | None = None, # if set, will use a LinearLR scheduler to reduce the learning rate to this factor * lr optim_args=None, # args of Optimizer with out lr smooth_grad=0.0, verbose=0, @@ -112,6 +114,7 @@ def __init__( transform_init=transform_init, optim_name=optim_name, lr=lr, + lr_end_factor=lr_end_factor, optim_args=optim_args, smooth_grad=smooth_grad, verbose=verbose, diff --git a/TPTBox/registration/deformable/multilabel_segmentation.py b/TPTBox/registration/deformable/multilabel_segmentation.py index d2b9a4b..d01305b 100644 --- a/TPTBox/registration/deformable/multilabel_segmentation.py +++ b/TPTBox/registration/deformable/multilabel_segmentation.py @@ -40,11 +40,13 @@ def __init__( loss_terms=None, # type: ignore weights=None, lr=0.01, + lr_end_factor=None, max_steps=1500, min_delta=1e-06, pyramid_levels=4, coarsest_level=3, finest_level=0, + crop: bool = True, cms_ids: list | None = None, poi_target_cms: POI | None = None, **args, @@ -110,41 +112,56 @@ def __init__( poi_target_cms[k1, k2] = (x, poi_target_cms.shape[1] - 1 - y, z) elif axis == 2: poi_target_cms[k1, k2] = (x, y, poi_target_cms.shape[2] - 1 - z) - print("crop") - crop = 50 - t_crop = (target).compute_crop(0, crop) - target = target.apply_crop(t_crop) - if atlas.is_segmentation_in_border(): - atlas = atlas.apply_pad(((1, 1), (1, 1), (1, 1))) - for i in range(10): # 1000, - if i != 0: - target = target.apply_pad(((25, 25), (25, 25), (25, 25))) - crop += 50 - t_crop = (target).compute_crop(0, crop) # if the angel is to different we need a larger crop... - target_ = target.apply_crop(t_crop) - # Point Registration - print("calc_centroids") + if crop: + print("crop") + crop = 50 + t_crop = (target).compute_crop(0, crop) + target = target.apply_crop(t_crop) + if atlas.is_segmentation_in_border(): + atlas = atlas.apply_pad(((1, 1), (1, 1), (1, 1))) + for i in range(10): # 1000, + if i != 0: + target = target.apply_pad(((25, 25), (25, 25), (25, 25))) + crop += 50 + t_crop = (target).compute_crop(0, crop) # if the angle is too different we need a larger crop...
+ target_ = target.apply_crop(t_crop) + # Point Registration + print("calc_centroids") + if poi_target_cms is None: + x = target_.extract_label(cms_ids, keep_label=True) if cms_ids else target_ + poi_target = calc_centroids(x, second_stage=40, bar=True) # TODO REMOVE + else: + poi_target = poi_target_cms.resample_from_to(target_) + if poi_cms is None: + x = atlas.extract_label(cms_ids, keep_label=True) if cms_ids else atlas + poi_cms = calc_centroids(x, second_stage=40, bar=True) # This will be needlessly computed all the time + if not poi_cms.assert_affine(atlas, raise_error=False): + poi_cms = poi_cms.resample_from_to(atlas) + self.reg_point = Point_Registration(poi_target, poi_cms) + atlas_reg = self.reg_point.transform_nii(atlas) + if atlas_reg.is_segmentation_in_border(): + print("atlas_reg does touch the border") + else: + target = target_ + break + else: + target_ = target if poi_target_cms is None: x = target_.extract_label(cms_ids, keep_label=True) if cms_ids else target_ poi_target = calc_centroids(x, second_stage=40, bar=True) # TODO REMOVE else: poi_target = poi_target_cms.resample_from_to(target_) - if poi_cms is None: - x = atlas.extract_label(cms_ids, keep_label=True) if cms_ids else atlas - poi_cms = calc_centroids(x, second_stage=40, bar=True) # This will be needlessly computed all the time - if not poi_cms.assert_affine(atlas, raise_error=False): - poi_cms = poi_cms.resample_from_to(atlas) self.reg_point = Point_Registration(poi_target, poi_cms) atlas_reg = self.reg_point.transform_nii(atlas) - if atlas_reg.is_segmentation_in_border(): - print("atlas_reg does touch the border") - else: - target = target_ - break + target = target_ - self.crop = (target + atlas_reg).compute_crop(0, 5) - target = target.apply_crop(self.crop) - atlas_reg = atlas_reg.apply_crop(self.crop) + if crop: + self.crop = (target + atlas_reg).compute_crop(0, 5) + target = target.apply_crop(self.crop) + atlas_reg = atlas_reg.apply_crop(self.crop) + else: + self.crop = None + self.target_grid = target.to_gird() self.reg_deform = Deformable_Registration( target, @@ -154,6 +171,7 @@ def __init__( loss_terms=loss_terms, weights=weights, lr=lr, + lr_end_factor=lr_end_factor, max_steps=max_steps, min_delta=min_delta, pyramid_levels=pyramid_levels, @@ -223,7 +241,7 @@ def load_(cls, w): self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop = x return self - def forward_nii(self, nii_atlas: NII): + def transform_nii(self, nii_atlas: NII): """ Apply both rigid and deformable registration to a new NII object. @@ -253,7 +271,7 @@ def forward_nii(self, nii_atlas: NII): return target - def forward_poi(self, poi_atlas: POI_Global | POI): + def transform_poi(self, poi_atlas: POI_Global | POI): """ Apply both rigid and deformable registration to a POI (landmark) object. 
diff --git a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py index 13a54e4..ec7dc46 100644 --- a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py @@ -41,7 +41,7 @@ def squash_so_it_fits_in_float16(x: NII): def run_inference_on_file( - idx, + idx: int | Path, input_nii: list[NII], out_file: str | Path | None = None, orientation=None, @@ -74,12 +74,16 @@ def run_inference_on_file( from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference # noqa: PLC0415 - download_weights(idx, model_path) - try: - nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNet*ResEnc*")) - except StopIteration: - nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*")) - folds = sorted([int(f.name.split("fold_")[-1]) for f in nnunet_path.glob("fold*")]) + if isinstance(idx, int): + download_weights(idx, model_path) + try: + nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNet*ResEnc*")) + except StopIteration: + nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*")) + else: + nnunet_path = Path(idx) + assert nnunet_path.exists(), nnunet_path + folds = sorted([f.name.split("fold_")[-1] for f in nnunet_path.glob("fold*")]) if max_folds is not None: folds = max_folds if isinstance(max_folds, list) else folds[:max_folds] From 5b6b8e941f7538e3d147914e33ec5a2bbd72b236 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 17 Oct 2025 15:36:28 +0000 Subject: [PATCH 5/6] ruff new ignore rule --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d9d0029..f40cc7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,7 @@ ignore = [ "B905", # strict= in zip "UP007", # Union and "|" python 3.9 "PLC0415", # import-outside-top-level + "RUF059", ] # Allow fix for all enabled rules (when `--fix`) is provided. From 3dcdf1b2cb7fcaf7d1aa5935db286134ff83b005 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 17 Oct 2025 15:40:47 +0000 Subject: [PATCH 6/6] ruff --- TPTBox/registration/deformable/multilabel_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TPTBox/registration/deformable/multilabel_segmentation.py b/TPTBox/registration/deformable/multilabel_segmentation.py index d01305b..2a6a76c 100644 --- a/TPTBox/registration/deformable/multilabel_segmentation.py +++ b/TPTBox/registration/deformable/multilabel_segmentation.py @@ -28,7 +28,7 @@ class Register_Multi_Seg: target_grid (NII): Cropped spatial grid used for deformable registration. """ - def __init__( + def __init__( # noqa: C901 self, target: NII, atlas: NII,
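---

Reviewer note: two short usage sketches for the main API additions in this series. Both are sketches under assumptions, not part of the patches -- the input path, the label list, and the concrete parameter values below are hypothetical; only the keyword names (dilate_channelwise, background_threshold, lr_end_factor) come from the diffs above.

Label-wise smoothing with the new options from PATCH 1 (assumes NII.load(path, seg=True) as the loader and that label_to_smooth accepts a list of labels):

    from TPTBox import NII

    seg = NII.load("seg.nii.gz", seg=True)  # hypothetical input file
    smoothed = seg.smooth_gaussian_labelwise(
        label_to_smooth=[1, 2],    # hypothetical labels
        sigma=3.0,
        dilate_prior=1,
        dilate_channelwise=True,   # new: dilate each label channel separately
        background_threshold=0.5,  # new: forwarded to np_smooth_gaussian_labelwise
    )

What the new lr_end_factor option wires up in deepali_trainer.on_optimizer, reproduced standalone with plain torch (runnable as-is; the single parameter is a dummy stand-in for the transform parameters):

    import torch
    from torch.optim.lr_scheduler import LinearLR

    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.Adam([p], lr=0.01)
    # start_factor=1.0, end_factor=lr_end_factor, total_iters=max_steps,
    # matching on_optimizer: the lr decays linearly from 0.01 to 0.001.
    sched = LinearLR(opt, start_factor=1.0, end_factor=0.1, total_iters=1500)
    for _ in range(1500):
        p.grad = torch.ones_like(p)  # dummy gradient so Adam takes a step
        opt.step()
        sched.step()

Note also that PATCH 4 renames Register_Multi_Seg.forward_nii/forward_poi to transform_nii/transform_poi, so existing call sites need the corresponding update.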