From 56ddf766da21cf48648adac183cb805b40a637bc Mon Sep 17 00:00:00 2001 From: robert Date: Thu, 31 Jul 2025 10:50:32 +0200 Subject: [PATCH 01/28] improved defaults --- TPTBox/registration/deepali/deepali_trainer.py | 14 +++++++++++--- .../deformable/multilabel_segmentation.py | 8 ++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/TPTBox/registration/deepali/deepali_trainer.py b/TPTBox/registration/deepali/deepali_trainer.py index a2edd5a..88db61f 100644 --- a/TPTBox/registration/deepali/deepali_trainer.py +++ b/TPTBox/registration/deepali/deepali_trainer.py @@ -101,7 +101,7 @@ def __init__( smooth_grad=0.0, verbose=0, max_steps: int | Sequence[int] = 250, # Early stopping. override on_converged finer control - max_history: int | None = None, + max_history: int | None = 100, # Used for on_converged. look at the last n sample to compute the convergence min_value=0.0, # Early stopping. override on_converged finer control min_delta: float | Sequence[float] = 0.0, # Early stopping. override on_converged finer control loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, @@ -222,7 +222,16 @@ def __init__( u = [a for a in u if a != 0] # Build a mapping from original label -> index (starting from 1) mapping = {int(label.item()): idx for idx, label in enumerate(u, 1)} + num_classes = len(mapping) + 1 # Add 1 for background or assume no 0 + u2 = torch.unique(self.source_seg_org.tensor()) + u2 = u2.detach().cpu() + u2 = [a for a in u2 if a != 0] + for idx in u2: + idx = int(idx.item()) # noqa: PLW2901 + if idx not in mapping: + print("Warning no matching idx found:", idx) + mapping[idx] = 0 # Remap the segmentation labels according to mapping source_remapped = self.source_seg_org.tensor().clone() target_remapped = self.target_seg_org.tensor().clone() @@ -231,8 +240,7 @@ def __init__( target_remapped[self.target_seg_org.tensor() == orig_label] = new_label # Convert to one-hot if needed (optional) - num_classes = len(mapping) + 1 # Add 1 for background or assume no 0 - print(f"Found {num_classes=}, {source_remapped.unique()}, {target_remapped.unique()}") + print(f"Found {num_classes=}, {source_remapped.unique()}, {target_remapped.unique()}; internal mapping: {mapping}") one_hot_source = ( (torch.nn.functional.one_hot(source_remapped.long(), num_classes).to(self._dtype).to(self.device)) .permute(0, 4, 1, 2, 3) diff --git a/TPTBox/registration/deformable/multilabel_segmentation.py b/TPTBox/registration/deformable/multilabel_segmentation.py index d2b9a4b..2435265 100644 --- a/TPTBox/registration/deformable/multilabel_segmentation.py +++ b/TPTBox/registration/deformable/multilabel_segmentation.py @@ -41,7 +41,7 @@ def __init__( weights=None, lr=0.01, max_steps=1500, - min_delta=1e-06, + min_delta: float | list[float] = 1e-06, pyramid_levels=4, coarsest_level=3, finest_level=0, @@ -116,10 +116,10 @@ def __init__( target = target.apply_crop(t_crop) if atlas.is_segmentation_in_border(): atlas = atlas.apply_pad(((1, 1), (1, 1), (1, 1))) - for i in range(10): # 1000, + for i in range(2): # 1000, if i != 0: - target = target.apply_pad(((25, 25), (25, 25), (25, 25))) - crop += 50 + target = target.apply_pad(((50, 50), (50, 50), (50, 50))) + crop += 75 t_crop = (target).compute_crop(0, crop) # if the angel is to different we need a larger crop... target_ = target.apply_crop(t_crop) # Point Registration From 24a9cd7164aa52b94d749649d11de3ab380e5ad7 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:42:18 +0000 Subject: [PATCH 02/28] add options to some functions (raise error, 2d infect, better defautls) --- TPTBox/core/nii_wrapper.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index d924d7b..7f15524 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -714,7 +714,7 @@ def reorient_(self:Self, axcodes_to: AX_CODES|None = ("P", "I", "R"), verbose:lo return self.reorient(axcodes_to=axcodes_to, verbose=verbose,inplace=True) - def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_crop:tuple[slice,...]|None=None, maximum_size:tuple[slice,...]|int|tuple[int,...]|None=None,)->tuple[slice,slice,slice]: + def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_crop:tuple[slice,...]|None=None, maximum_size:tuple[slice,...]|int|tuple[int,...]|None=None, raise_error=True)->tuple[slice,slice,slice]: """ Computes the minimum slice that removes unused space from the image and returns the corresponding slice tuple along with the origin shift required for centroids. @@ -723,6 +723,7 @@ def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_cro dist (int): The amount of padding to be added to the cropped image. Default value is 0. use_mm: dist will be mm instead of number of voxels other_crop (tuple[slice,...], optional): A tuple of slice objects representing the slice of an other image to be combined with the current slice. Default value is None. + raise_error: if crop is empty a "ValueError: bbox_nd: img is empty, cannot calculate a bbox" is produced. When False return None instead. Returns: ex_slice: A tuple of slice objects that need to be applied to crop the image. @@ -738,7 +739,7 @@ def compute_crop(self,minimum: float=0, dist: float = 0, use_mm=False, other_cro d = np.around(dist / np.asarray(self.zoom)).astype(int) if use_mm else (int(dist),int(dist),int(dist)) array = self.get_array() #+ minimum - ex_slice = list(np_bbox_binary(array > minimum, px_dist=d)) + ex_slice = list(np_bbox_binary(array > minimum, px_dist=d,raise_error=raise_error)) if other_crop is not None: assert all((a.step is None) for a in other_crop), 'Only None slice is supported for combining x' @@ -1127,6 +1128,7 @@ def smooth_gaussian_labelwise( dilate_connectivity: int = 1, smooth_background: bool = True, inplace: bool = False, + verbose:logging=False, ): """Smoothes the segmentation mask by applying a gaussian filter label-wise and then using argmax to derive the smoothed segmentation labels again. @@ -1145,6 +1147,7 @@ def smooth_gaussian_labelwise( NII: The smoothed NII object. """ assert self.seg, "You cannot use this on a non-segmentation NII" + log.print("smooth_gaussian_labelwise",verbose=verbose) smoothed = np_smooth_gaussian_labelwise(self.get_seg_array(), label_to_smooth=label_to_smooth, sigma=sigma, radius=radius, truncate=truncate, boundary_mode=boundary_mode, dilate_prior=dilate_prior, dilate_connectivity=dilate_connectivity,smooth_background=smooth_background,) return self.set_array(smoothed,inplace,verbose=False) @@ -1363,7 +1366,7 @@ def boundary_mask(self, threshold: float,inplace = False): """ return self.set_array(np_calc_boundary_mask(self.get_array(),threshold),inplace=inplace,verbose=False) - def get_connected_components(self, labels: int |list[int]=1, connectivity: int = 3, include_zero: bool=False,inplace=False) -> Self: # noqa: ARG002 + def get_connected_components(self, labels: int |list[int]|None=None, connectivity: int = 3, include_zero: bool=False,inplace=False) -> Self: # noqa: ARG002 assert self.seg, "This only works on segmentations" out, _ = np_connected_components(self.get_seg_array(), label_ref=labels, connectivity=connectivity, include_zero=include_zero) return self.set_array(out,inplace=inplace) @@ -1681,7 +1684,8 @@ def infect_conv(self: NII, reference_mask: NII, max_iters=100,inplace=False): org = self.get_seg_array() org[crop] = self_mask return self.set_array(org,inplace=inplace) - def infect(self: NII, reference_mask: NII, inplace=False,verbose=True): + + def infect(self: NII, reference_mask: NII, inplace=False,verbose=True,axis:int|str|None=None): """ Expands labels from self_mask into regions of reference_mask == 1 via breadth-first diffusion. @@ -1699,13 +1703,22 @@ def infect(self: NII, reference_mask: NII, inplace=False,verbose=True): ref_mask = np.clip(reference_mask.get_seg_array(), 0, 1) ref_mask[self_mask_org != 0] = 0 searched = np.clip(self_mask,0,1).astype(np.uint8) - ndim = len(self_mask.shape) # Define neighborhood kernel - if ndim == 3: + if axis is None: kernel = [(1,0,0),(0,1,0),(0,0,1),(-1,0,0),(0,-1,0),(0,0,-1)] else: - raise NotImplementedError("Only 2D or 3D masks are supported.") + if isinstance(axis,str): + axis = self.get_axis(axis) + if axis == 0: + kernel = [(0,1,0),(0,0,1),(0,-1,0),(0,0,-1)] + elif axis == 1: + kernel = [(1,0,0),(0,0,1),(-1,0,0),(0,0,-1)] + elif axis == 2: + kernel = [(1,0,0),(0,1,0),(-1,0,0),(0,-1,0)] + else: + raise NotImplementedError(axis) + search = [] coords = np.where(self_mask != 0) def _add_idx(x,y,z,v): From cd20aa8edcf329bda9060bd7590526467195a930 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:42:53 +0000 Subject: [PATCH 03/28] update constance Full Body --- TPTBox/core/vert_constants.py | 90 +++++++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index 79d4e94..dd60c5f 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -148,8 +148,12 @@ class Full_Body_Instance(Abstract_lvl): scapula_left = 103 humerus_right = 4 humerus_left = 104 - hand_rest_right = 5 - hand_rest_left = 105 + hand_right = 5 + hand_left = 105 + radius_right = 60 + radius_left = 160 + ulna_right = 61 + ulna_left = 161 sternum = 6 costal_cartilage = 7 rib_right = 8 @@ -229,7 +233,87 @@ class Full_Body_Instance(Abstract_lvl): subcutaneous_fat = 57 muscle_other = 58 inner_fat = 59 - ignore = 60 + ignore = 63 + + @classmethod + def get_totalVibeSegMapping(cls): + return { + 1: Full_Body_Instance.spleen.value, # spleen + 2: Full_Body_Instance.kidney_right.value, # kidney_right + 3: Full_Body_Instance.kidney_left.value, # kidney_left + 4: Full_Body_Instance.gallbladder.value, # gallbladder + 5: Full_Body_Instance.liver.value, # liver + 6: Full_Body_Instance.stomach.value, # stomach + 7: Full_Body_Instance.pancreas.value, # pancreas + 8: Full_Body_Instance.adrenal_gland_right.value, # adrenal_gland_right + 9: Full_Body_Instance.adrenal_gland_left.value, # adrenal_gland_left + 10: Full_Body_Instance.lung_left.value, # lung_upper_lobe_left + 11: Full_Body_Instance.lung_left.value, # lung_lower_lobe_left + 12: Full_Body_Instance.lung_right.value, # lung_upper_lobe_right + 13: Full_Body_Instance.lung_right.value, # lung_middle_lobe_right + 14: Full_Body_Instance.lung_right.value, # lung_lower_lobe_right + 15: Full_Body_Instance.esophagus.value, # esophagus + 16: Full_Body_Instance.trachea.value, # trachea + 17: Full_Body_Instance.thyroid_gland_right.value, # thyroid_gland + 18: Full_Body_Instance.intestine.value, # intestine + 19: Full_Body_Instance.doudenum.value, # duodenum + 20: Full_Body_Instance.rib_right.value, # unused + 21: Full_Body_Instance.urinary_bladder.value, # urinary_bladder + 22: Full_Body_Instance.prostate.value, # prostate + 23: Full_Body_Instance.sacrum.value, # sacrum + 24: Full_Body_Instance.heart.value, # heart + 25: Full_Body_Instance.aorta.value, # aorta + 26: Full_Body_Instance.pulmonary_vein.value, # pulmonary_vein + 27: Full_Body_Instance.brachiocephalic_trunk.value, # brachiocephalic_trunk + 28: Full_Body_Instance.subclavian_artery_right.value, # subclavian_artery_right + 29: Full_Body_Instance.subclavian_artery_left.value, # subclavian_artery_left + 30: Full_Body_Instance.common_carotid_artery_right.value, # common_carotid_artery_right + 31: Full_Body_Instance.common_carotid_artery_left.value, # common_carotid_artery_left + 32: Full_Body_Instance.brachiocephalic_vein_left.value, # brachiocephalic_vein_left + 33: Full_Body_Instance.brachiocephalic_vein_right.value, # brachiocephalic_vein_right + 34: Full_Body_Instance.atrial_appendage_left.value, # atrial_appendage_left + 35: Full_Body_Instance.superior_vena_cava.value, # superior_vena_cava + 36: Full_Body_Instance.inferior_vena_cava.value, # inferior_vena_cava + 37: Full_Body_Instance.portal_vein_and_splenic_vein.value, # portal_vein_and_splenic_vein + 38: Full_Body_Instance.iliac_artery_left.value, # iliac_artery_left + 39: Full_Body_Instance.iliac_artery_right.value, # iliac_artery_right + 40: Full_Body_Instance.iliac_vena_left.value, # iliac_vena_left + 41: Full_Body_Instance.iliac_vena_right.value, # iliac_vena_right + 42: Full_Body_Instance.humerus_left.value, # humerus_left + 43: Full_Body_Instance.humerus_right.value, # humerus_right + 44: Full_Body_Instance.scapula_left.value, # scapula_left + 45: Full_Body_Instance.scapula_right.value, # scapula_right + 46: Full_Body_Instance.clavicula_left.value, # clavicula_left + 47: Full_Body_Instance.clavicula_right.value, # clavicula_right + 48: Full_Body_Instance.femur_left.value, # femur_left + 49: Full_Body_Instance.femur_right.value, # femur_right + 50: Full_Body_Instance.hip_left.value, # hip_left + 51: Full_Body_Instance.hip_right.value, # hip_right + 52: Full_Body_Instance.channel.value, # spinal_cord + 53: Full_Body_Instance.gluteus_maximus_left.value, # gluteus_maximus_left + 54: Full_Body_Instance.gluteus_maximus_right.value, # gluteus_maximus_right + 55: Full_Body_Instance.gluteus_medius_left.value, # gluteus_medius_left + 56: Full_Body_Instance.gluteus_medius_right.value, # gluteus_medius_right + 57: Full_Body_Instance.gluteus_minimus_left.value, # gluteus_minimus_left + 58: Full_Body_Instance.gluteus_minimus_right.value, # gluteus_minimus_right + 59: Full_Body_Instance.autochthon_left.value, # autochthon_left + 60: Full_Body_Instance.autochthon_right.value, # autochthon_right + 61: Full_Body_Instance.iliopsoas_left.value, # iliopsoas_left + 62: Full_Body_Instance.iliopsoas_right.value, # iliopsoas_right + 63: Full_Body_Instance.sternum.value, # sternum + 64: Full_Body_Instance.costal_cartilage.value, # costal_cartilages + 65: Full_Body_Instance.subcutaneous_fat.value, # subcutaneous_fat + 66: Full_Body_Instance.muscle_other.value, # muscle + 67: Full_Body_Instance.inner_fat.value, # inner_fat + 68: Full_Body_Instance.ivd.value, # IVD + 69: Full_Body_Instance.vert_body.value, # vertebra_body + 70: Full_Body_Instance.vert_post.value, # vertebra_posterior_elements + 71: Full_Body_Instance.channel.value, # spinal_channel + 72: Full_Body_Instance.ignore.value, # bone_other + 73: 0, + 77: 0, # Negative + 100: Full_Body_Instance.ignore.value, + } class Lower_Body(Abstract_lvl): From f45386c18d0effec71e0c3dbefa29bf479d56875 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:43:19 +0000 Subject: [PATCH 04/28] 3.9 comp --- TPTBox/core/poi_fun/save_mkr.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py index d99357e..e1c0a59 100644 --- a/TPTBox/core/poi_fun/save_mkr.py +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import random from pathlib import Path ###### GLOBAL POI ##### -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union import numpy as np from typing_extensions import NotRequired @@ -107,7 +109,7 @@ class MeasurementVolumeMarkup(Markup, total=False): boundingBox: list[float] -MKR_DEFINITION = MKR_Lines | dict +MKR_DEFINITION = Union[MKR_Lines, dict] def _get_display_dict( @@ -226,7 +228,7 @@ def _make_default_markup( def _get_markup_lines( definition: MKR_Lines, - poi: "POI_Global", + poi: POI_Global, coordinate_system: Literal["LPS", "RAS"], split_by_region=False, split_by_subregion=False, @@ -235,7 +237,6 @@ def _get_markup_lines( if display is None: display = {} key_points = definition.get("key_points") - region, subregion = key_points[0] color = _get_markup_color(definition, region, subregion, split_by_region, split_by_subregion) display = _get_display_dict(display | definition.get("display", {}), selectedColor=color) @@ -251,7 +252,7 @@ def _get_markup_lines( return _make_default_markup("Line", definition.get("name"), coordinate_system, controlPoints=controlPoints, display=display) -def get_desc(self: "POI_Global", region, subregion): +def get_desc(self: POI_Global, region, subregion): try: name = self.level_two_info(subregion).name except Exception: @@ -263,8 +264,18 @@ def get_desc(self: "POI_Global", region, subregion): return name, name2 +def _get_key(region, subregion, split_by_region, split_by_subregion): + key = "P" + if split_by_region: + key += str(region) + "_" + if split_by_subregion: + key += str(subregion) + + return key + + def _save_mrk( - poi: "POI_Global", + poi: POI_Global, filepath: str | Path, color=None, split_by_region=True, @@ -301,11 +312,7 @@ def _save_mrk( if add_points: # Create list of control points for region, subregion, coords in poi.centroids.items(): - key = "P" - if split_by_region: - key += str(region) + "_" - if split_by_subregion: - key += str(subregion) + key = _get_key(region, subregion, split_by_region, split_by_subregion) name, name2 = get_desc(poi, region, subregion) if key not in list_markups: list_markups[key] = _make_default_markup( @@ -325,7 +332,7 @@ def _save_mrk( _get_control_point({}, coords, f"{region}-{subregion}", f"{region}-{subregion}", name, name2) ) markups = list(list_markups.values()) - # Lines + [markups.append(_get_markup_lines(line, poi, coordinate_system, split_by_region, split_by_subregion, display)) for line in add_lines] mrk_data = { "@schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#", From 34483e91cba496b31440300dc43144f0af5c36a5 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:44:20 +0000 Subject: [PATCH 05/28] snapshot3d can now update the resolution correctly --- TPTBox/mesh3D/snapshot3D.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py index 5fe5af3..b825f18 100644 --- a/TPTBox/mesh3D/snapshot3D.py +++ b/TPTBox/mesh3D/snapshot3D.py @@ -35,6 +35,7 @@ def make_snapshot3D( smoothing=20, resolution: float | None = None, width_factor=1.0, + png_magnify=1, verbose=True, crop=True, ) -> Image.Image: @@ -80,7 +81,7 @@ def make_snapshot3D( nii = to_nii_seg(img) if crop: try: - nii.apply_crop_(nii.compute_crop()) + nii.apply_crop_(nii.compute_crop(0, 5)) except ValueError: pass if resolution is None: @@ -100,17 +101,17 @@ def make_snapshot3D( # TOP : ("A", "I", "R") nii = nii.reorient(("A", "S", "L")).rescale_((resolution, resolution, resolution)) width = int(max(nii.shape[0], nii.shape[2]) * width_factor) - window_size = (width * len(ids_list), nii.shape[1]) + window_size = (width * len(ids_list) * png_magnify, nii.shape[1] * png_magnify) with Xvfb(): scene = window.Scene() - show_m = window.ShowManager(scene=scene, size=window_size, reset_camera=False) + show_m = window.ShowManager(scene=scene, size=window_size, reset_camera=False, png_magnify=png_magnify) show_m.initialize() for i, ids in enumerate(ids_list): x = width * i _plot_sub_seg(scene, nii.extract_label(ids, keep_label=True), x, 0, smoothing, view[i % len(view)]) scene.projection(proj_type="parallel") scene.reset_camera_tight(margin_factor=1.02) - window.record(scene, size=window_size, out_path=output_path, reset_camera=False) + window.record(scene=scene, size=window_size, out_path=output_path, reset_camera=False) scene.clear() if not is_tmp: logger.on_save("Save Snapshot3D:", output_path, verbose=verbose) @@ -121,15 +122,17 @@ def make_snapshot3D( def make_snapshot3D_parallel( - imgs: list[Path | str], - output_paths: list[Image_Reference], + imgs: list[Image_Reference], + output_paths: list[Path | str], view: VIEW | list[VIEW] = "A", ids_list: list[Sequence[int]] | None = None, smoothing=20, - resolution: float = 2, + resolution: float = 1, cpus=10, width_factor=1.0, + png_magnify=1, override=True, + crop=True, ): ress = [] with Pool(cpus) as p: # type: ignore @@ -146,6 +149,8 @@ def make_snapshot3D_parallel( "smoothing": smoothing, "resolution": resolution, "width_factor": width_factor, + "png_magnify": png_magnify, + "crop": crop, }, ) ress.append(res) From e94f4e792deddc00e1694508ebc27f27b45a1bec Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:46:39 +0000 Subject: [PATCH 06/28] update max_histroy for early stoping --- TPTBox/registration/deepali/deepali_model.py | 14 +++-- .../registration/deepali/deepali_trainer.py | 56 +++++++++++++------ .../registration/deformable/deformable_reg.py | 4 +- .../deformable/multilabel_segmentation.py | 14 +++-- 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/TPTBox/registration/deepali/deepali_model.py b/TPTBox/registration/deepali/deepali_model.py index 581ac62..4d89f90 100644 --- a/TPTBox/registration/deepali/deepali_model.py +++ b/TPTBox/registration/deepali/deepali_model.py @@ -6,7 +6,7 @@ import time from collections.abc import Sequence from pathlib import Path -from typing import Literal, Optional, Self, Union +from typing import Literal, Union import torch import torch.optim @@ -15,16 +15,18 @@ from deepali.core import Grid as Deepali_Grid from deepali.data import Image as deepaliImage from deepali.modules import TransformImage -from deepali.spatial import ( - SpatialTransform, -) +from deepali.spatial import SpatialTransform +from typing_extensions import Self from TPTBox import NII, POI, Image_Reference, to_nii from TPTBox.core.compat import zip_strict from TPTBox.core.internal.deep_learning_utils import DEVICES, get_device from TPTBox.core.nii_poi_abstract import Grid as TPTBox_Grid from TPTBox.core.nii_poi_abstract import Has_Grid -from TPTBox.registration.deepali.deepali_trainer import LOSS, DeepaliPairwiseImageTrainer +from TPTBox.registration.deepali.deepali_trainer import ( + LOSS, + DeepaliPairwiseImageTrainer, +) def center_of_mass(tensor): @@ -192,7 +194,7 @@ def __init__( smooth_grad=0.0, verbose=99, max_steps: int | Sequence[int] = 250, # Early stopping. override on_converged finer control - max_history: int | None = None, + max_history: int | None = 100, min_value=0.0, # Early stopping. override on_converged finer control min_delta: float | Sequence[float] = 0.0, # Early stopping. override on_converged finer control loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, diff --git a/TPTBox/registration/deepali/deepali_trainer.py b/TPTBox/registration/deepali/deepali_trainer.py index 88db61f..ff0bff4 100644 --- a/TPTBox/registration/deepali/deepali_trainer.py +++ b/TPTBox/registration/deepali/deepali_trainer.py @@ -197,6 +197,32 @@ def __init__( self.min_delta = min_delta if not isinstance(min_delta, (Sequence)) else min_delta[::-1] self.verbose = verbose self.loss_terms, self.weights = parse_loss(loss_terms, weights) + + self.pyramid_levels = pyramid_levels + self.finest_level = finest_level + self.coarsest_level = coarsest_level + self.dims = dims + self.pyramid_finest_spacing = pyramid_finest_spacing + self.pyramid_min_size = pyramid_min_size + self._load_all(source, target, source_seg, target_seg, source_mask, target_mask) + + self.source_pset = source_pset + self.target_pset = target_pset + self.source_landmarks = source_landmarks + self.target_landmarks = target_landmarks + self.smooth_grad = smooth_grad + self._eval_hooks = OrderedDict() + self._step_hooks = OrderedDict() + + def _load_all( + self, + source, + target, + source_seg, + target_seg, + source_mask=None, + target_mask=None, + ): # reading images self.source = self._read(source) self.target = self._read(target) @@ -208,9 +234,15 @@ def __init__( # normalize self.source, self.target = self.on_normalize(self.source, self.target) # Pyramid - self.source_pyramid, self.target_pyramid = self.make_pyramid( - self.source, self.target, pyramid_levels, finest_level, coarsest_level, dims, pyramid_finest_spacing, pyramid_min_size + self.source, + self.target, + self.pyramid_levels, + self.finest_level, + self.coarsest_level, + self.dims, + self.pyramid_finest_spacing, + self.pyramid_min_size, ) if source_seg is not None or target_seg is not None: with torch.no_grad(): @@ -270,12 +302,12 @@ def __init__( self.source_pyramid_seg, self.target_pyramid_seg = self.make_pyramid( self.source_seg, self.target_seg, - pyramid_levels, - finest_level, - coarsest_level, - dims, - pyramid_finest_spacing, - pyramid_min_size, + self.pyramid_levels, + self.finest_level, + self.coarsest_level, + self.dims, + self.pyramid_finest_spacing, + self.pyramid_min_size, ) print("make_pyramid seg end", self.source_seg.dtype) else: @@ -284,14 +316,6 @@ def __init__( self.source_pyramid_seg = None self.target_pyramid_seg = None - self.source_pset = source_pset - self.target_pset = target_pset - self.source_landmarks = source_landmarks - self.target_landmarks = target_landmarks - self.smooth_grad = smooth_grad - self._eval_hooks = OrderedDict() - self._step_hooks = OrderedDict() - def on_normalize(self, source: Image, target: Image): return normalize_img(source, self.normalize_strategy), normalize_img(target, self.normalize_strategy) diff --git a/TPTBox/registration/deformable/deformable_reg.py b/TPTBox/registration/deformable/deformable_reg.py index 0a9c73f..00a6182 100644 --- a/TPTBox/registration/deformable/deformable_reg.py +++ b/TPTBox/registration/deformable/deformable_reg.py @@ -2,7 +2,7 @@ # pip install hf-deepali from collections.abc import Sequence -from typing import Literal, Optional, Union +from typing import Literal, Union import numpy as np import torch @@ -67,7 +67,7 @@ def __init__( smooth_grad=0.0, verbose=0, max_steps: int | Sequence[int] = 1000, # Early stopping. override on_converged finer controle - max_history: int | None = None, + max_history: int | None = 100, min_value=0.0, # Early stopping. override on_converged finer controle min_delta: float | Sequence[float] = -0.0001, # Early stopping. override on_converged finer controle loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, diff --git a/TPTBox/registration/deformable/multilabel_segmentation.py b/TPTBox/registration/deformable/multilabel_segmentation.py index 2435265..e6bd523 100644 --- a/TPTBox/registration/deformable/multilabel_segmentation.py +++ b/TPTBox/registration/deformable/multilabel_segmentation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from pathlib import Path @@ -32,8 +34,8 @@ def __init__( self, target: NII, atlas: NII, - poi_cms: POI | None, - same_side: bool, + poi_cms: POI | None = None, + same_side: bool = True, verbose=99, gpu=0, ddevice: DEVICES = "cuda", @@ -47,6 +49,8 @@ def __init__( finest_level=0, cms_ids: list | None = None, poi_target_cms: POI | None = None, + max_history=100, + change_after_point_reg=lambda x, y: (x, y), **args, ): """ @@ -142,10 +146,11 @@ def __init__( target = target_ break target = target_ - self.crop = (target + atlas_reg).compute_crop(0, 5) + self.crop = (target + atlas_reg).compute_crop(0, 10) target = target.apply_crop(self.crop) atlas_reg = atlas_reg.apply_crop(self.crop) self.target_grid = target.to_gird() + target, atlas_reg = change_after_point_reg(target, atlas_reg) self.reg_deform = Deformable_Registration( target, atlas_reg, @@ -162,6 +167,7 @@ def __init__( verbose=verbose, gpu=gpu, ddevice=ddevice, + max_history=max_history, **args, ) @@ -220,7 +226,7 @@ def load_(cls, w): self = cls.__new__(cls) self.reg_point = Point_Registration.load_(t0) self.reg_deform = Deformable_Registration.load_(t1) - self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop = x + (self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop) = x return self def forward_nii(self, nii_atlas: NII): From 0307165dc68a5689268ee91522bffbf41f760d2b Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:47:04 +0000 Subject: [PATCH 07/28] add wrapper for file names --- TPTBox/segmentation/spineps.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/TPTBox/segmentation/spineps.py b/TPTBox/segmentation/spineps.py index 7eeab82..f3592f5 100644 --- a/TPTBox/segmentation/spineps.py +++ b/TPTBox/segmentation/spineps.py @@ -13,6 +13,27 @@ def injection_function(seg_nii: NII): return seg_nii +def get_outpaths_spineps_single( + file_path: str | Path | BIDS_FILE, + dataset=None, + derivative_name="derivative", + ignore_bids_filter=True, +): + from spineps.seg_run import output_paths_from_input + + if not isinstance(file_path, BIDS_FILE): + file_path = Path(file_path) + file_path = BIDS_FILE(file_path, file_path.parent if dataset is None else dataset) + output_paths = output_paths_from_input( + file_path, + derivative_name, + None, + input_format=file_path.format, + non_strict_mode=ignore_bids_filter, + ) + return output_paths + + def run_spineps_single( file_path: str | Path | BIDS_FILE, dataset=None, @@ -181,7 +202,11 @@ def _run_spineps_vert( verbose=True, use_cpu=False, ): - from spineps import get_instance_model, phase_postprocess_combined, predict_instance_mask + from spineps import ( + get_instance_model, + phase_postprocess_combined, + predict_instance_mask, + ) from spineps.get_models import get_actual_model if isinstance(model_instance, Path): From a7e1864a9b987049a1e7b8615c1dcac42c9d896b Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 1 Sep 2025 11:48:43 +0000 Subject: [PATCH 08/28] update inference --- .../TotalVibeSeg/inference_nnunet.py | 33 +++++++++++++++++-- .../segmentation/TotalVibeSeg/totalvibeseg.py | 2 +- TPTBox/segmentation/nnUnet_utils/predictor.py | 4 +++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py index 13a54e4..8681524 100644 --- a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py @@ -86,11 +86,20 @@ def run_inference_on_file( # if idx in _unets: # nnunet = _unets[idx] # else: - print("load model", nnunet_path.name, "; folds", folds) if verbose else None + print("load model", nnunet_path, "; folds", folds) if verbose else None with open(Path(nnunet_path, "plans.json")) as f: plans_info = json.load(f) with open(Path(nnunet_path, "dataset.json")) as f: ds_info = json.load(f) + inference_config = Path(nnunet_path, "inference_config.json") + if inference_config.exists(): + with open() as f: + ds_info2 = json.load(f) + if "model_expected_orientation" in ds_info2: + ds_info["orientation"] = ds_info2["model_expected_orientation"] + if "resolution_range" in ds_info2: + ds_info["spacing"] = ds_info2["resolution_range"] + nnunet = load_inf_model( nnunet_path, allow_non_final=True, @@ -107,13 +116,30 @@ def run_inference_on_file( # _unets[idx] = nnunet if "orientation" in ds_info: orientation = ds_info["orientation"] + zoom = None + orientation_ref = None og_nii = input_nii[0].copy() try: + zoom_old = ds_info.get("spacing") zoom = plans_info["configurations"]["3d_fullres"]["spacing"] order = plans_info["transpose_backward"] + # order2 = plans_info["transpose_forward"] zoom = [zoom[order[0]], zoom[order[1]], zoom[order[2]]][::-1] + orientation_ref = ("P", "I", "R") + orientation_ref = [ + orientation_ref[order[0]], + orientation_ref[order[1]], + orientation_ref[order[2]], + ] # [::-1] + + # zoom_old = zoom_old[::-1] + if zoom is None: + pass + + else: + zoom = [float(z) for z in zoom] except Exception: pass assert len(ds_info["channel_names"]) == len(input_nii), ( @@ -123,12 +149,13 @@ def run_inference_on_file( nnunet_path, ) if orientation is not None: - print("orientation", orientation) if verbose else None + print("orientation", orientation, f"{orientation_ref=}") if verbose else None input_nii = [i.reorient(orientation) for i in input_nii] if zoom is not None: - print("rescale", zoom) if verbose else None + print("rescale", zoom, f"{zoom_old=}, {order=}") if verbose else None input_nii = [i.rescale_(zoom, mode=mode) for i in input_nii] + print(input_nii) print("squash to float16") if verbose else None input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] diff --git a/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py b/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py index ef8febc..2a2d5ee 100644 --- a/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py +++ b/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py @@ -88,7 +88,7 @@ def run_totalvibeseg( override=False, gpu=0, ddevice: Literal["cpu", "cuda", "mps"] = "cuda", - dataset_id=80, + dataset_id=100, padd=0, keep_size=False, # Keep size of the model Segmentation **args, diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index eecbb1f..a917b54 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -567,6 +567,8 @@ def _run_prediction_splits( predicted_logits, n_predictions, _, _ = self._allocate(data, "cpu", pbar) for e, i in enumerate(inter_mediate_slice, 1): slices = i.get_intermediate() + if slices is None: + continue sub_data = data[slices] logits, n_pred = self._run_sub( sub_data, network, self.device, i.get_slices(), pbar, addendum=f"chunks={e}/{len(inter_mediate_slice)}" @@ -689,6 +691,8 @@ def add_slicer(self, s: tuple[slice, ...]): self.slicers.append(s) def get_intermediate(self): + if self.min_s is None or self.max_s is None: + return None return ( slice(None), *tuple(slice(mi, ma) for mi, ma in zip(self.min_s, self.max_s)), From dd831219db3ebddc1e4ac819fcf0e4e7df4f1217 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Thu, 11 Sep 2025 13:10:50 +0000 Subject: [PATCH 09/28] add better logging --- TPTBox/core/vert_constants.py | 45 +++++++++++++++++++ TPTBox/mesh3D/snapshot3D.py | 2 +- .../ridged_points/point_registration.py | 4 +- 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index dd60c5f..35f954b 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -235,6 +235,51 @@ class Full_Body_Instance(Abstract_lvl): inner_fat = 59 ignore = 63 + @classmethod + def bone(cls): + return [ + Full_Body_Instance.skull, + Full_Body_Instance.clavicula_right, + Full_Body_Instance.clavicula_left, + Full_Body_Instance.scapula_right, + Full_Body_Instance.scapula_left, + Full_Body_Instance.humerus_right, + Full_Body_Instance.humerus_left, + Full_Body_Instance.hand_right, + Full_Body_Instance.hand_left, + Full_Body_Instance.radius_right, + Full_Body_Instance.radius_left, + Full_Body_Instance.ulna_right, + Full_Body_Instance.ulna_left, + Full_Body_Instance.sternum, + Full_Body_Instance.costal_cartilage, + Full_Body_Instance.rib_right, + Full_Body_Instance.rib_left, + Full_Body_Instance.vert_body, + Full_Body_Instance.vert_post, + Full_Body_Instance.sacrum, + Full_Body_Instance.hip_right, + Full_Body_Instance.hip_left, + Full_Body_Instance.femur_right, + Full_Body_Instance.femur_left, + Full_Body_Instance.patella_right, + Full_Body_Instance.patella_left, + Full_Body_Instance.tibia_right, + Full_Body_Instance.tibia_left, + Full_Body_Instance.fibula_right, + Full_Body_Instance.fibula_left, + Full_Body_Instance.talus_right, + Full_Body_Instance.talus_left, + Full_Body_Instance.calcaneus_right, + Full_Body_Instance.calcaneus_left, + Full_Body_Instance.tarsals_right, + Full_Body_Instance.tarsals_left, + Full_Body_Instance.metatarsals_right, + Full_Body_Instance.metatarsals_left, + Full_Body_Instance.phalanges_right, + Full_Body_Instance.phalanges_left, + ] + @classmethod def get_totalVibeSegMapping(cls): return { diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py index b825f18..24d1889 100644 --- a/TPTBox/mesh3D/snapshot3D.py +++ b/TPTBox/mesh3D/snapshot3D.py @@ -77,7 +77,7 @@ def make_snapshot3D( if output_path is None: t = NamedTemporaryFile(suffix="_snap3D.png") # noqa: SIM115 output_path = str(t.name) - Path(output_path).parent.mkdir(exist_ok=True) + Path(output_path).parent.mkdir(exist_ok=True, parents=True) nii = to_nii_seg(img) if crop: try: diff --git a/TPTBox/registration/ridged_points/point_registration.py b/TPTBox/registration/ridged_points/point_registration.py index b53bf97..c5b65c2 100755 --- a/TPTBox/registration/ridged_points/point_registration.py +++ b/TPTBox/registration/ridged_points/point_registration.py @@ -79,7 +79,9 @@ def __init__( if len(inter) < 2: log.print("[!] To few points, skip registration", Log_Type.FAIL) - raise ValueError("[!] To few points, skip registration", inter) + raise ValueError( + f"[!] To few points, skip registration; {poi_fixed.keys()=}; {poi_moving.keys()=}", + ) img_movig = poi_moving.make_empty_nii() assert img_movig.shape == poi_moving.shape_int, (img_movig, poi_moving.shape) assert img_movig.orientation == poi_moving.orientation From 87429c51aa7a5eaad69a26732105209fede0d6a4 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Thu, 11 Sep 2025 13:12:29 +0000 Subject: [PATCH 10/28] add option --- TPTBox/core/poi_fun/strategies.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/TPTBox/core/poi_fun/strategies.py b/TPTBox/core/poi_fun/strategies.py index 58c4a2b..86bbee0 100644 --- a/TPTBox/core/poi_fun/strategies.py +++ b/TPTBox/core/poi_fun/strategies.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Literal import numpy as np from numpy.linalg import norm @@ -169,14 +170,16 @@ def strategy_ligament_attachment_point_flava( goal: Location | np.ndarray, log: Logger_Interface = _log, delta=0.0000001, + shift_direction: Literal["S", "I"] = "S", + dir2: Literal["A"] = "A", ): if vert_id in sacrum_w_o_arcus: return try: - normal_vector1 = get_direction("S", poi, vert_id) # / np.array(poi.zoom) + normal_vector1 = get_direction(shift_direction, poi, vert_id) # / np.array(poi.zoom) v1 = normal_vector1 / norm(normal_vector1) - normal_vector2 = get_direction("A", poi, vert_id) # / np.array(poi.zoom) + normal_vector2 = get_direction(dir2, poi, vert_id) # / np.array(poi.zoom) v2 = normal_vector2 / norm(normal_vector2) except KeyError: return From 0ba8bdcd9a190259290fb20b4c17897ee54cfed5 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Wed, 1 Oct 2025 06:55:05 +0000 Subject: [PATCH 11/28] small changes --- TPTBox/core/nii_wrapper.py | 6 +- TPTBox/core/poi_fun/poi_global.py | 8 +- TPTBox/core/poi_fun/save_load.py | 17 +++- TPTBox/core/poi_fun/save_mkr.py | 95 ++++++++++++++++++---- TPTBox/core/vert_constants.py | 96 ++++++++++++++++++++-- TPTBox/mesh3D/mesh_colors.py | 130 +++++++++++++++++++++++++++++- 6 files changed, 323 insertions(+), 29 deletions(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 7f15524..842eaaa 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -286,6 +286,7 @@ def load_nrrd(cls, path: str | Path, seg: bool): # ('space origin', array([-249.51171875, -392.51171875, 119.7]))]) # Construct the affine transformation matrix + #print(header) try: #print(header['space directions']) #print(header['space origin']) @@ -300,7 +301,7 @@ def load_nrrd(cls, path: str | Path, seg: bool): if m != n: n=m data = data.sum(axis=0) - space_directions = space_directions.T + space_directions = space_directions.T if space_directions.shape != (n, n): raise ValueError(f"Expected 'space directions' to be a nxn matrix. n = {n} is not {space_directions.shape}",space_directions) if space_origin.shape != (n,): @@ -309,6 +310,7 @@ def load_nrrd(cls, path: str | Path, seg: bool): affine = np.eye(n+1) # Initialize 4x4 identity matrix affine[:n, :n] = space_directions # Set rotation and scaling affine[:n, n] = space_origin # Set translation + #print(affine,space) if space =="left-posterior-superior": #LPS (SITK-space) affine[0] *=-1 affine[1] *=-1 @@ -320,7 +322,7 @@ def load_nrrd(cls, path: str | Path, seg: bool): pass else: raise ValueError(space) - + #print(affine) except KeyError as e: raise KeyError(f"Missing expected header field: {e}") from None diff --git a/TPTBox/core/poi_fun/poi_global.py b/TPTBox/core/poi_fun/poi_global.py index 3636676..e7c642f 100755 --- a/TPTBox/core/poi_fun/poi_global.py +++ b/TPTBox/core/poi_fun/poi_global.py @@ -23,12 +23,14 @@ class POI_Global(Abstract_POI): def __init__( self, - input_poi: poi.POI | POI_Descriptor | dict[str, dict[str, tuple[float, ...]]], + input_poi: poi.POI | POI_Descriptor | dict[str, dict[str, tuple[float, ...]]] = None, itk_coords: bool = False, level_one_info: type[Abstract_lvl] | None = None, # Must be Enum and must has order_dict level_two_info: type[Abstract_lvl] | None = None, info: dict | None = None, ): + if input_poi is None: + input_poi = {} args = {} if level_one_info is not None: args["level_one_info"] = level_one_info @@ -157,6 +159,8 @@ def copy(self, centroids: POI_Descriptor | None = None) -> Self: if centroids is None: centroids = self.centroids.copy() p = POI_Global(centroids) + p.level_one_info = self.level_one_info + p.level_two_info = self.level_two_info p.format = self.format p.info = deepcopy(self.info) p.itk_coords = self.itk_coords @@ -195,6 +199,7 @@ def save_mrk( display: save_mkr.MKR_Display | dict = None, # type: ignore pointLabelsVisibility=False, glyphScale=5.0, + main_key="Point", ): save_mkr._save_mrk( poi=self, @@ -207,4 +212,5 @@ def save_mrk( display=display, pointLabelsVisibility=pointLabelsVisibility, glyphScale=glyphScale, + main_key=main_key, ) diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index 276a88d..f277238 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ b/TPTBox/core/poi_fun/save_load.py @@ -137,7 +137,11 @@ def convert(o): def _poi_to_dict_list( # noqa: C901 - ctd: POI | POI_Global, additional_info: dict | None, save_hint=0, resample_reference: Has_Grid | None = None, verbose: logging = False + ctd: POI | POI_Global, + additional_info: dict | None, + save_hint=0, + resample_reference: Has_Grid | None = None, + verbose: logging = False, ): from TPTBox import POI, POI_Global @@ -274,7 +278,13 @@ def load_poi(ctd_path: POI_Reference, verbose=True) -> POI | POI_Global: # noqa centroids = POI_Descriptor() itk_coords = global_spacing_name_key2value[dict_list[0].get("coordinate_system", "nib")] _load_POI_centroids(dict_list, centroids, level_one_info, level_two_info) - return POI_Global(centroids, itk_coords=itk_coords, level_one_info=level_one_info, level_two_info=level_two_info, info=info) + return POI_Global( + centroids, + itk_coords=itk_coords, + level_one_info=level_one_info, + level_two_info=level_two_info, + info=info, + ) ### Ours ### assert "direction" in dict_list[0], f'File format error: first index must be a "Direction" but got {dict_list[0]}' @@ -462,6 +472,7 @@ def _load_mkr_POI(dict_mkr: dict): itk_coords = None if dict_mkr.get("coordinateSystem") in ["LPS", "RAS"]: itk_coords = dict_mkr["coordinateSystem"] == "LPS" + label_name = {} for markup in dict_mkr["markups"]: if markup["type"] != "Fiducial": log.on_warning("skip unknown markup type:", markup["type"]) @@ -493,6 +504,7 @@ def _load_mkr_POI(dict_mkr: dict): # orientation = controlPoints.get("orientation", None) region, subregion = _get_poi_idx_from_text(idx, label, centroids) centroids[region, subregion] = tuple(position) + label_name[str((region, subregion))] = label assert itk_coords is not None, "itk_coords not set" from TPTBox import POI_Global @@ -500,4 +512,5 @@ def _load_mkr_POI(dict_mkr: dict): if "display" in dict_mkr and "color" in dict_mkr["display"]: # TODO keep all display, locked etc info in the info dict poi.info["color"] = dict_mkr["display"]["color"] + poi.info["label_name"] = label_name return poi diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py index e1c0a59..f3fc779 100644 --- a/TPTBox/core/poi_fun/save_mkr.py +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -5,7 +5,7 @@ from pathlib import Path ###### GLOBAL POI ##### -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union import numpy as np from typing_extensions import NotRequired @@ -113,7 +113,12 @@ class MeasurementVolumeMarkup(Markup, total=False): def _get_display_dict( - display: MKR_Display | dict, color=None, selectedColor=None, activeColor=None, pointLabelsVisibility=False, glyphScale=1.0 + display: MKR_Display | dict, + color=None, + selectedColor=None, + activeColor=None, + pointLabelsVisibility=False, + glyphScale=1.0, ): if activeColor is None: activeColor = [0.4, 1.0, 0.0] @@ -155,7 +160,13 @@ def _get_display_dict( } -def _get_markup_color(definition: MKR_DEFINITION, region, subregion, split_by_region=False, split_by_subregion=False): +def _get_markup_color( + definition: MKR_DEFINITION, + region, + subregion, + split_by_region=False, + split_by_subregion=False, +): color = definition.get("color", None) if color is None: if split_by_region: @@ -190,7 +201,11 @@ def _get_control_point(cp: ControlPoint, position, id_name="", label="", name="" def _make_default_markup( - markup_type: MarkupType, name, coordinateSystem: CoordinateSystem, controlPoints=None, display=None + markup_type: MarkupType, + name, + coordinateSystem: CoordinateSystem, + controlPoints=None, + display=None, ) -> Markup | MeasurementVolumeMarkup: if controlPoints is None: controlPoints = [] @@ -211,7 +226,13 @@ def _make_default_markup( if markup_type == "ROI": base.update({"roiType": "Box", "insideOut": False}) elif markup_type == "Plane": - base.update({"planeType": "PointNormal", "sizeMode": "auto", "autoScalingSizeFactor": 1.0}) + base.update( + { + "planeType": "PointNormal", + "sizeMode": "auto", + "autoScalingSizeFactor": 1.0, + } + ) elif markup_type == "MeasurementVolume": mv: MeasurementVolumeMarkup = { **base, @@ -246,13 +267,27 @@ def _get_markup_lines( name, name2 = get_desc(poi, region, subregion) controlPoints.append( _get_control_point( - definition.get("controlPoint", {}), poi[region, subregion], f"{region}-{subregion}", f"{region}-{subregion}", name, name2 + definition.get("controlPoint", {}), + poi[region, subregion], + f"{region}-{subregion}", + f"{region}-{subregion}", + name, + name2, ) ) - return _make_default_markup("Line", definition.get("name"), coordinate_system, controlPoints=controlPoints, display=display) + return _make_default_markup( + "Line", + definition.get("name"), + coordinate_system, + controlPoints=controlPoints, + display=display, + ) def get_desc(self: POI_Global, region, subregion): + label = self.info.get("label_name", {}).get(str((region, subregion))) + if label is None: + label = f"{region}-{subregion}" try: name = self.level_two_info(subregion).name except Exception: @@ -261,11 +296,11 @@ def get_desc(self: POI_Global, region, subregion): name2 = self.level_one_info(region).name except Exception: name2 = str(region) - return name, name2 + return name, name2, label -def _get_key(region, subregion, split_by_region, split_by_subregion): - key = "P" +def _get_key(region, subregion, split_by_region, split_by_subregion, main_key="POI"): + key = main_key if split_by_region: key += str(region) + "_" if split_by_subregion: @@ -285,6 +320,7 @@ def _save_mrk( display: MKR_Display | dict = None, # type: ignore pointLabelsVisibility=False, glyphScale=1.0, + main_key="P", **args, ): """ @@ -312,8 +348,14 @@ def _save_mrk( if add_points: # Create list of control points for region, subregion, coords in poi.centroids.items(): - key = _get_key(region, subregion, split_by_region, split_by_subregion) - name, name2 = get_desc(poi, region, subregion) + key = _get_key( + region, + subregion, + split_by_region, + split_by_subregion, + main_key=main_key, + ) + name, name2, label = get_desc(poi, region, subregion) if key not in list_markups: list_markups[key] = _make_default_markup( "Fiducial", @@ -323,17 +365,40 @@ def _save_mrk( display=_get_display_dict( display, selectedColor=_get_markup_color( - {"color": color}, region, subregion, split_by_region=split_by_subregion, split_by_subregion=split_by_subregion + {"color": color}, + region, + subregion, + split_by_region=split_by_subregion, + split_by_subregion=split_by_subregion, ), **addendum, ), ) list_markups[key]["controlPoints"].append( - _get_control_point({}, coords, f"{region}-{subregion}", f"{region}-{subregion}", name, name2) + _get_control_point( + {}, + coords, + f"{region}-{subregion}", + label, + name, + name2, + ) ) markups = list(list_markups.values()) - [markups.append(_get_markup_lines(line, poi, coordinate_system, split_by_region, split_by_subregion, display)) for line in add_lines] + [ + markups.append( + _get_markup_lines( + line, + poi, + coordinate_system, + split_by_region, + split_by_subregion, + display, + ) + ) + for line in add_lines + ] mrk_data = { "@schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#", "markups": markups, diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index 180287e..dbd7ef5 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -214,6 +214,84 @@ class Full_Body_Instance_Vibe(Abstract_lvl): spinal_channel = 71 bone_other = 72 + @classmethod + def get_Full_Body_Instance_mapping(cls): + return { + Full_Body_Instance.spleen.value: cls.spleen.value, # spleen + Full_Body_Instance.kidney_right.value: cls.kidney_right.value, # kidney_right + Full_Body_Instance.kidney_left.value: cls.kidney_left.value, # kidney_left + Full_Body_Instance.gallbladder.value: cls.gallbladder.value, # gallbladder + Full_Body_Instance.liver.value: cls.liver.value, # liver + Full_Body_Instance.stomach.value: cls.stomach.value, # stomach + Full_Body_Instance.pancreas.value: cls.pancreas.value, # pancreas + Full_Body_Instance.adrenal_gland_right.value: cls.adrenal_gland_right.value, # adrenal_gland_right + Full_Body_Instance.adrenal_gland_left.value: cls.adrenal_gland_left.value, # adrenal_gland_left + Full_Body_Instance.lung_left.value: cls.lung_upper_lobe_left.value, # lung_upper_lobe_left + Full_Body_Instance.lung_left.value: cls.lung_lower_lobe_left.value, # lung_lower_lobe_left + Full_Body_Instance.lung_right.value: cls.lung_upper_lobe_right.value, # lung_upper_lobe_right + Full_Body_Instance.lung_right.value: cls.lung_middle_lobe_right.value, # lung_middle_lobe_right + Full_Body_Instance.lung_right.value: cls.lung_lower_lobe_right.value, # lung_lower_lobe_right + Full_Body_Instance.esophagus.value: cls.esophagus.value, # esophagus + Full_Body_Instance.trachea.value: cls.trachea.value, # trachea + Full_Body_Instance.thyroid_gland_right.value: cls.thyroid_gland.value, # thyroid_gland + Full_Body_Instance.intestine.value: cls.intestine.value, # intestine + Full_Body_Instance.doudenum.value: cls.duodenum.value, # duodenum + Full_Body_Instance.rib_right.value: cls.unused.value, # unused + Full_Body_Instance.urinary_bladder.value: cls.urinary_bladder.value, # urinary_bladder + Full_Body_Instance.prostate.value: cls.prostate.value, # prostate + Full_Body_Instance.sacrum.value: cls.sacrum.value, # sacrum + Full_Body_Instance.heart.value: cls.heart.value, # heart + Full_Body_Instance.aorta.value: cls.aorta.value, # aorta + Full_Body_Instance.pulmonary_vein.value: cls.pulmonary_vein.value, # pulmonary_vein + Full_Body_Instance.brachiocephalic_trunk.value: cls.brachiocephalic_trunk.value, # brachiocephalic_trunk + Full_Body_Instance.subclavian_artery_right.value: cls.subclavian_artery_right.value, # subclavian_artery_right + Full_Body_Instance.subclavian_artery_left.value: cls.subclavian_artery_left.value, # subclavian_artery_left + Full_Body_Instance.common_carotid_artery_right.value: cls.common_carotid_artery_right.value, # common_carotid_artery_right + Full_Body_Instance.common_carotid_artery_left.value: cls.common_carotid_artery_left.value, # common_carotid_artery_left + Full_Body_Instance.brachiocephalic_vein_left.value: cls.brachiocephalic_vein_left.value, # brachiocephalic_vein_left + Full_Body_Instance.brachiocephalic_vein_right.value: cls.brachiocephalic_vein_right.value, # brachiocephalic_vein_right + Full_Body_Instance.atrial_appendage_left.value: cls.atrial_appendage_left.value, # atrial_appendage_left + Full_Body_Instance.superior_vena_cava.value: cls.superior_vena_cava.value, # superior_vena_cava + Full_Body_Instance.inferior_vena_cava.value: cls.inferior_vena_cava.value, # inferior_vena_cava + Full_Body_Instance.portal_vein_and_splenic_vein.value: cls.portal_vein_and_splenic_vein.value, # portal_vein_and_splenic_vein + Full_Body_Instance.iliac_artery_left.value: cls.iliac_artery_left.value, # iliac_artery_left + Full_Body_Instance.iliac_artery_right.value: cls.iliac_artery_right.value, # iliac_artery_right + Full_Body_Instance.iliac_vena_left.value: cls.iliac_vena_left.value, # iliac_vena_left + Full_Body_Instance.iliac_vena_right.value: cls.iliac_vena_right.value, # iliac_vena_right + Full_Body_Instance.humerus_left.value: cls.humerus_left.value, # humerus_left + Full_Body_Instance.humerus_right.value: cls.humerus_right.value, # humerus_right + Full_Body_Instance.scapula_left.value: cls.scapula_left.value, # scapula_left + Full_Body_Instance.scapula_right.value: cls.scapula_right.value, # scapula_right + Full_Body_Instance.clavicula_left.value: cls.clavicula_left.value, # clavicula_left + Full_Body_Instance.clavicula_right.value: cls.clavicula_right.value, # clavicula_right + Full_Body_Instance.femur_left.value: cls.femur_left.value, # femur_left + Full_Body_Instance.femur_right.value: cls.femur_right.value, # femur_right + Full_Body_Instance.pelvis_left.value: cls.hip_left.value, # hip_left + Full_Body_Instance.pelvis_right.value: cls.hip_right.value, # hip_right + Full_Body_Instance.channel.value: cls.spinal_cord.value, # spinal_cord + Full_Body_Instance.gluteus_maximus_left.value: cls.gluteus_maximus_left.value, # gluteus_maximus_left + Full_Body_Instance.gluteus_maximus_right.value: cls.gluteus_maximus_right.value, # gluteus_maximus_right + Full_Body_Instance.gluteus_medius_left.value: cls.gluteus_medius_left.value, # gluteus_medius_left + Full_Body_Instance.gluteus_medius_right.value: cls.gluteus_medius_right.value, # gluteus_medius_right + Full_Body_Instance.gluteus_minimus_left.value: cls.gluteus_minimus_left.value, # gluteus_minimus_left + Full_Body_Instance.gluteus_minimus_right.value: cls.gluteus_minimus_right.value, # gluteus_minimus_right + Full_Body_Instance.autochthon_left.value: cls.autochthon_left.value, # autochthon_left + Full_Body_Instance.autochthon_right.value: cls.autochthon_right.value, # autochthon_right + Full_Body_Instance.iliopsoas_left.value: cls.iliopsoas_left.value, # iliopsoas_left + Full_Body_Instance.iliopsoas_right.value: cls.iliopsoas_right.value, # iliopsoas_right + Full_Body_Instance.sternum.value: cls.sternum.value, # sternum + Full_Body_Instance.costal_cartilage.value: cls.costal_cartilages.value, # costal_cartilages + Full_Body_Instance.subcutaneous_fat.value: cls.subcutaneous_fat.value, # subcutaneous_fat + Full_Body_Instance.muscle_other.value: cls.muscle.value, # muscle + Full_Body_Instance.inner_fat.value: cls.inner_fat.value, # inner_fat + Full_Body_Instance.ivd.value: cls.IVD.value, # IVD + Full_Body_Instance.vert_body.value: cls.vertebra_body.value, # vertebra_body + Full_Body_Instance.vert_post.value: cls.vertebra_posterior_elements.value, # vertebra_posterior_elements + Full_Body_Instance.channel.value: cls.spinal_channel.value, # spinal_channel + Full_Body_Instance.ignore.value: cls.bone_other.value, # bone_other + 100: Full_Body_Instance.ignore.value, + } + class Full_Body_Instance(Abstract_lvl): skull = 1 @@ -236,8 +314,8 @@ class Full_Body_Instance(Abstract_lvl): vert_body = 9 vert_post = 10 sacrum = 11 - hip_right = 12 - hip_left = 112 + pelvis_right = 12 + pelvis_left = 112 femur_right = 13 femur_left = 113 patella_right = 14 @@ -333,8 +411,8 @@ def bone(cls): Full_Body_Instance.vert_body, Full_Body_Instance.vert_post, Full_Body_Instance.sacrum, - Full_Body_Instance.hip_right, - Full_Body_Instance.hip_left, + Full_Body_Instance.pelvis_right, + Full_Body_Instance.pelvis_left, Full_Body_Instance.femur_right, Full_Body_Instance.femur_left, Full_Body_Instance.patella_right, @@ -356,7 +434,7 @@ def bone(cls): ] @classmethod - def get_totalVibeSegMapping(cls): + def get_VIBESeg_mapping(cls): return { 1: Full_Body_Instance.spleen.value, # spleen 2: Full_Body_Instance.kidney_right.value, # kidney_right @@ -407,8 +485,8 @@ def get_totalVibeSegMapping(cls): 47: Full_Body_Instance.clavicula_right.value, # clavicula_right 48: Full_Body_Instance.femur_left.value, # femur_left 49: Full_Body_Instance.femur_right.value, # femur_right - 50: Full_Body_Instance.hip_left.value, # hip_left - 51: Full_Body_Instance.hip_right.value, # hip_right + 50: Full_Body_Instance.pelvis_left.value, # hip_left + 51: Full_Body_Instance.pelvis_right.value, # hip_right 52: Full_Body_Instance.channel.value, # spinal_cord 53: Full_Body_Instance.gluteus_maximus_left.value, # gluteus_maximus_left 54: Full_Body_Instance.gluteus_maximus_right.value, # gluteus_maximus_right @@ -452,7 +530,7 @@ class Lower_Body(Abstract_lvl): TROCHLEA_GROOVE_CENTRAL_POINT = 10 # Femur - HIP_CENTER = 11 + PELVIS_CENTER = 11 NECK_CENTER = 12 TIP_OF_GREATER_TROCHANTER = 13 LATERAL_CONDYLE_POSTERIOR = 14 @@ -501,6 +579,8 @@ def __init__( self._rib = ( vertebra_label + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET if vertebra_label != 28 else 21 + VERTEBRA_INSTANCE_RIB_LABEL_OFFSET ) + # 40 - 8 + 21 = 53 = rib for T13 + # 52 rib for L1 if has_ivd: self._ivd = vertebra_label + VERTEBRA_INSTANCE_IVD_LABEL_OFFSET self._endplate = vertebra_label + VERTEBRA_INSTANCE_ENDPLATE_LABEL_OFFSET diff --git a/TPTBox/mesh3D/mesh_colors.py b/TPTBox/mesh3D/mesh_colors.py index cb60381..a3c156e 100755 --- a/TPTBox/mesh3D/mesh_colors.py +++ b/TPTBox/mesh3D/mesh_colors.py @@ -108,7 +108,7 @@ class Mesh_Color_List: ITK_48 = RGB_Color.init_list([112, 128, 144]) ITK_49 = RGB_Color.init_list([34, 139, 34]) ITK_50 = RGB_Color.init_list([248, 248, 255]) - ITK_51 = RGB_Color.init_list([245, 255, 250]) + ITK_51 = RGB_Color.init_list([145, 255, 150]) ITK_52 = RGB_Color.init_list([255, 160, 122]) ITK_53 = RGB_Color.init_list([144, 238, 144]) ITK_54 = RGB_Color.init_list([173, 255, 47]) @@ -131,7 +131,135 @@ class Mesh_Color_List: ITK_71 = RGB_Color.init_list([255, 250, 240]) ITK_72 = RGB_Color.init_list([0, 206, 209]) + ITK_73 = RGB_Color.init_list([0, 255, 127]) + ITK_74 = RGB_Color.init_list([128, 0, 128]) + ITK_75 = RGB_Color.init_list([255, 250, 205]) + ITK_76 = RGB_Color.init_list([250, 128, 114]) + ITK_77 = RGB_Color.init_list([148, 0, 211]) + ITK_78 = RGB_Color.init_list([178, 34, 34]) + ITK_79 = RGB_Color.init_list([255, 127, 80]) + ITK_80 = RGB_Color.init_list([135, 206, 235]) + ITK_81 = RGB_Color.init_list([100, 149, 237]) + ITK_82 = RGB_Color.init_list([240, 230, 140]) + ITK_83 = RGB_Color.init_list([250, 235, 215]) + ITK_84 = RGB_Color.init_list([255, 245, 238]) + ITK_85 = RGB_Color.init_list([107, 142, 35]) + ITK_86 = RGB_Color.init_list([135, 206, 250]) + ITK_87 = RGB_Color.init_list([0, 0, 139]) + ITK_88 = RGB_Color.init_list([139, 0, 139]) + ITK_89 = RGB_Color.init_list([245, 245, 220]) + ITK_90 = RGB_Color.init_list([186, 85, 211]) + ITK_91 = RGB_Color.init_list([255, 228, 181]) + ITK_92 = RGB_Color.init_list([255, 222, 173]) + ITK_93 = RGB_Color.init_list([0, 191, 255]) + ITK_94 = RGB_Color.init_list([210, 105, 30]) + ITK_95 = RGB_Color.init_list([255, 248, 220]) + ITK_96 = RGB_Color.init_list([47, 79, 79]) + ITK_97 = RGB_Color.init_list([72, 61, 139]) + ITK_98 = RGB_Color.init_list([175, 238, 238]) + ITK_99 = RGB_Color.init_list([128, 128, 0]) ITK_100 = RGB_Color.init_list([176, 224, 230]) + ITK_101 = RGB_Color.init_list([255, 240, 245]) + ITK_102 = RGB_Color.init_list([139, 0, 0]) + ITK_103 = RGB_Color.init_list([240, 255, 255]) + ITK_104 = RGB_Color.init_list([255, 215, 0]) + ITK_105 = RGB_Color.init_list([216, 191, 216]) + ITK_106 = RGB_Color.init_list([119, 136, 153]) + ITK_107 = RGB_Color.init_list([219, 112, 147]) + ITK_108 = RGB_Color.init_list([72, 209, 204]) + ITK_109 = RGB_Color.init_list([255, 0, 255]) + ITK_110 = RGB_Color.init_list([199, 21, 133]) + ITK_111 = RGB_Color.init_list([154, 205, 50]) + ITK_112 = RGB_Color.init_list([189, 183, 107]) + ITK_113 = RGB_Color.init_list([240, 248, 255]) + ITK_114 = RGB_Color.init_list([230, 230, 250]) + ITK_115 = RGB_Color.init_list([0, 250, 154]) + ITK_116 = RGB_Color.init_list([85, 107, 47]) + ITK_117 = RGB_Color.init_list([64, 224, 208]) + ITK_118 = RGB_Color.init_list([153, 50, 204]) + ITK_119 = RGB_Color.init_list([205, 92, 92]) + ITK_120 = RGB_Color.init_list([250, 250, 210]) + ITK_121 = RGB_Color.init_list([95, 158, 160]) + ITK_122 = RGB_Color.init_list([0, 128, 0]) + ITK_123 = RGB_Color.init_list([255, 69, 0]) + ITK_124 = RGB_Color.init_list([224, 255, 255]) + ITK_125 = RGB_Color.init_list([176, 196, 222]) + ITK_126 = RGB_Color.init_list([138, 43, 226]) + ITK_127 = RGB_Color.init_list([30, 144, 255]) + ITK_128 = RGB_Color.init_list([240, 128, 128]) + ITK_129 = RGB_Color.init_list([152, 251, 152]) + ITK_130 = RGB_Color.init_list([160, 82, 45]) + ITK_131 = RGB_Color.init_list([255, 0, 0]) + ITK_132 = RGB_Color.init_list([0, 255, 0]) + ITK_133 = RGB_Color.init_list([0, 0, 255]) + ITK_134 = RGB_Color.init_list([255, 255, 0]) + ITK_135 = RGB_Color.init_list([0, 255, 255]) + ITK_136 = RGB_Color.init_list([255, 0, 255]) + ITK_137 = RGB_Color.init_list([255, 239, 213]) + ITK_138 = RGB_Color.init_list([0, 0, 205]) + ITK_139 = RGB_Color.init_list([205, 133, 63]) + ITK_140 = RGB_Color.init_list([210, 180, 140]) + ITK_141 = RGB_Color.init_list([102, 205, 170]) + ITK_142 = RGB_Color.init_list([0, 0, 128]) + ITK_143 = RGB_Color.init_list([0, 139, 139]) + ITK_144 = RGB_Color.init_list([46, 139, 87]) + ITK_145 = RGB_Color.init_list([255, 228, 225]) + ITK_146 = RGB_Color.init_list([106, 90, 205]) + ITK_147 = RGB_Color.init_list([221, 160, 221]) + ITK_148 = RGB_Color.init_list([233, 150, 122]) + ITK_149 = RGB_Color.init_list([165, 42, 42]) + ITK_150 = RGB_Color.init_list([255, 250, 250]) + ITK_151 = RGB_Color.init_list([147, 112, 219]) + ITK_152 = RGB_Color.init_list([218, 112, 214]) + ITK_153 = RGB_Color.init_list([75, 0, 130]) + ITK_154 = RGB_Color.init_list([255, 182, 193]) + ITK_155 = RGB_Color.init_list([60, 179, 113]) + ITK_156 = RGB_Color.init_list([255, 235, 205]) + ITK_157 = RGB_Color.init_list([255, 228, 196]) + ITK_158 = RGB_Color.init_list([218, 165, 32]) + ITK_159 = RGB_Color.init_list([0, 128, 128]) + ITK_160 = RGB_Color.init_list([188, 143, 143]) + ITK_161 = RGB_Color.init_list([255, 105, 180]) + ITK_162 = RGB_Color.init_list([255, 218, 185]) + ITK_163 = RGB_Color.init_list([222, 184, 135]) + ITK_164 = RGB_Color.init_list([127, 255, 0]) + ITK_165 = RGB_Color.init_list([139, 69, 19]) + ITK_166 = RGB_Color.init_list([124, 252, 0]) + ITK_167 = RGB_Color.init_list([255, 255, 224]) + ITK_168 = RGB_Color.init_list([70, 130, 180]) + ITK_169 = RGB_Color.init_list([0, 100, 0]) + ITK_170 = RGB_Color.init_list([238, 130, 238]) + ITK_171 = RGB_Color.init_list([238, 232, 170]) + ITK_172 = RGB_Color.init_list([240, 255, 240]) + ITK_173 = RGB_Color.init_list([245, 222, 179]) + ITK_174 = RGB_Color.init_list([184, 134, 11]) + ITK_175 = RGB_Color.init_list([32, 178, 170]) + ITK_176 = RGB_Color.init_list([255, 20, 147]) + ITK_177 = RGB_Color.init_list([25, 25, 112]) + ITK_178 = RGB_Color.init_list([112, 128, 144]) + ITK_179 = RGB_Color.init_list([34, 139, 34]) + ITK_180 = RGB_Color.init_list([248, 248, 255]) + ITK_181 = RGB_Color.init_list([245, 255, 250]) + ITK_182 = RGB_Color.init_list([255, 160, 122]) + ITK_183 = RGB_Color.init_list([144, 238, 144]) + ITK_184 = RGB_Color.init_list([173, 255, 47]) + ITK_185 = RGB_Color.init_list([65, 105, 225]) + ITK_186 = RGB_Color.init_list([255, 99, 71]) + ITK_187 = RGB_Color.init_list([250, 240, 230]) + ITK_188 = RGB_Color.init_list([128, 0, 0]) + ITK_189 = RGB_Color.init_list([50, 205, 50]) + ITK_190 = RGB_Color.init_list([244, 164, 96]) + ITK_191 = RGB_Color.init_list([255, 255, 240]) + ITK_192 = RGB_Color.init_list([123, 104, 238]) + ITK_193 = RGB_Color.init_list([255, 165, 0]) + ITK_194 = RGB_Color.init_list([173, 216, 230]) + ITK_195 = RGB_Color.init_list([255, 192, 203]) + ITK_196 = RGB_Color.init_list([127, 255, 212]) + ITK_197 = RGB_Color.init_list([255, 140, 0]) + ITK_198 = RGB_Color.init_list([143, 188, 143]) + ITK_199 = RGB_Color.init_list([220, 20, 60]) + ITK_200 = RGB_Color.init_list([253, 245, 230]) + ITK_201 = RGB_Color.init_list([255, 250, 240]) _color_dict = {v: getattr(Mesh_Color_List, v) for v in vars(Mesh_Color_List) if not callable(v) and not v.startswith("__")} From d7b4a47ba83273aa0cd959e17b335153bc93bfda Mon Sep 17 00:00:00 2001 From: robert Date: Wed, 1 Oct 2025 08:56:18 +0200 Subject: [PATCH 12/28] prevent error for unreadable data, like object data --- TPTBox/core/nii_wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 7f15524..3a3ba85 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -99,7 +99,10 @@ def _check_if_nifty_is_lying_about_its_dtype(self: NII): arr = self.nii.dataobj dtype = self._nii.dataobj.dtype # type: ignore dtype_s = str(self._nii.dataobj.dtype) - mi = np.min(arr) + try: + mi = np.min(arr) + except Exception: + return np.float32 ma = np.max(arr) has_neg = mi < 0 max_v = _dtype_max.get(str(dtype), 0) From 389b305d7c8237851aa546af4b99127b7d8c00df Mon Sep 17 00:00:00 2001 From: ga84mun Date: Tue, 14 Oct 2025 20:51:05 +0000 Subject: [PATCH 13/28] small bugfixes --- TPTBox/core/bids_files.py | 3 ++- TPTBox/core/nii_poi_abstract.py | 2 ++ TPTBox/core/poi_fun/save_load.py | 9 ++++++++- TPTBox/core/poi_fun/save_mkr.py | 7 +++++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index 64d98ca..972b2ef 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -774,7 +774,8 @@ def get_changed_path( # noqa: C901 if key in same_info: continue if value is not None: - assert validate_entities(key, value, f"..._{key}-{value}_...", True) + if not non_strict_mode: + assert validate_entities(key, value, f"..._{key}-{value}_...", True), f"..._{key}-{value}_..." final_info[key] = value # file_name += f"{key}-{value}_" # sort by order diff --git a/TPTBox/core/nii_poi_abstract.py b/TPTBox/core/nii_poi_abstract.py index 15c377d..e04edce 100755 --- a/TPTBox/core/nii_poi_abstract.py +++ b/TPTBox/core/nii_poi_abstract.py @@ -2,6 +2,7 @@ import sys import warnings +from dataclasses import dataclass from typing import TYPE_CHECKING import nibabel as nib @@ -435,6 +436,7 @@ def get_num_dims(self): return len(self.shape) +@dataclass class Grid(Has_Grid): def __init__(self, **qargs) -> None: super().__init__() diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index f277238..249f756 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ b/TPTBox/core/poi_fun/save_load.py @@ -200,6 +200,12 @@ def _poi_to_dict_list( # noqa: C901 elif save_hint in (FORMAT_POI, FORMAT_GLOBAL): v_name = ctd.level_one_info._get_name(vert_id, no_raise=True) subreg_id = ctd.level_two_info._get_name(subreg_id, no_raise=True) # noqa: PLW2901 + if "_ignore_level_one_info_range" in ctd.info: + try: + if vert_id in ctd.info["_ignore_level_one_info_range"]: + v_name = vert_id + except Exception: + pass # sub_name = v_idx2name[subreg_id] if v_name not in temp_dict: temp_dict[v_name] = {} @@ -465,7 +471,8 @@ def _load_mkr_POI(dict_mkr: dict): if "@schema" not in dict_mkr or "markups-schema-v1.0.3" not in dict_mkr["@schema"]: log.on_warning( - "this file is possible incompatible. Tested only with markups-schema-v1.0.3 and not", dict_mkr.get("@schmea", "No Schema") + "this file is possible incompatible. Tested only with markups-schema-v1.0.3 and not", + dict_mkr.get("@schmea", "No Schema"), ) if "markups" not in dict_mkr: raise ValueError("markups is missing") diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py index f3fc779..f04bf0f 100644 --- a/TPTBox/core/poi_fun/save_mkr.py +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -294,6 +294,13 @@ def get_desc(self: POI_Global, region, subregion): name = str(subregion) try: name2 = self.level_one_info(region).name + if "_ignore_level_one_info_range" in self.info: + try: + if region in self.info["_ignore_level_one_info_range"]: + name2 = str(region) + except Exception: + pass + except Exception: name2 = str(region) return name, name2, label From 8711979252a5d14890ca2e3825323a38383b6262 Mon Sep 17 00:00:00 2001 From: robert Date: Tue, 14 Oct 2025 23:01:20 +0200 Subject: [PATCH 14/28] add 3D dicom supprot --- TPTBox/core/dicom/dicom_extract.py | 67 ++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/TPTBox/core/dicom/dicom_extract.py b/TPTBox/core/dicom/dicom_extract.py index 8f9e0c7..0a0075c 100644 --- a/TPTBox/core/dicom/dicom_extract.py +++ b/TPTBox/core/dicom/dicom_extract.py @@ -11,6 +11,7 @@ import dicom2nifti import dicom2nifti.exceptions +import nibabel as nib import numpy as np import pydicom from dicom2nifti import convert_dicom @@ -87,6 +88,50 @@ def _generate_bids_path( return fname.file["json"], fname +def dicom_to_nifti_multiframe(ds, nii_path): + pixel_array = ds.pixel_array + n_frames, n_rows, n_cols = pixel_array.shape + + # Pixel spacing (mm) + if hasattr(ds, "PixelSpacing"): + dy, dx = [float(v) for v in ds.PixelSpacing] + # Image orientation (row and column direction cosines) + orientation = [float(v) for v in ds.ImageOrientationPatient] + row_cosines = np.array(orientation[0:3]) + col_cosines = np.array(orientation[3:6]) + # Normal vector (slice direction) + slice_cosines = np.cross(row_cosines, col_cosines) + + # Image position (origin of first slice) + origin = np.array([float(v) for v in ds.ImagePositionPatient]) + # Slice spacing - robust: Abstand zwischen Slice 0 und 1 + if n_frames > 1: + pos1 = np.array([float(v) for v in ds.PerFrameFunctionalGroupsSequence[0].PlanePositionSequence[0].ImagePositionPatient]) + pos2 = np.array([float(v) for v in ds.PerFrameFunctionalGroupsSequence[1].PlanePositionSequence[0].ImagePositionPatient]) + dz = np.linalg.norm(pos2 - pos1) + else: + dz = float(getattr(ds, "SpacingBetweenSlices", ds.SliceThickness)) + + # Affine bauen + affine = np.eye(4) + affine[0:3, 0] = row_cosines * dx + affine[0:3, 1] = col_cosines * dy + affine[0:3, 2] = slice_cosines * dz + affine[0:3, 3] = origin + else: + dy, dx = [float(v) for v in ds.ImagerPixelSpacing] + # Einfaches affine (nur 2D + Zeit, keine Lage im Patientenraum) + affine = np.eye(4) + affine[0, 0] = -dx + affine[1, 1] = -dy + + # Reihenfolge anpassen: Nibabel erwartet (X,Y,Z) + nii = nib.Nifti1Image(np.transpose(pixel_array, (2, 1, 0)), affine) + nib.save(nii, nii_path) + + return nii_path + + def _convert_to_nifti(dicom_out_path, nii_path): """ Convert DICOM files to NIfTI format and handle common conversion errors. @@ -105,6 +150,15 @@ def _convert_to_nifti(dicom_out_path, nii_path): """ try: if isinstance(dicom_out_path, list): + try: + if len(dicom_out_path) == 1: + ds = dicom_out_path[0] + if hasattr(ds, "pixel_array") and len(ds.pixel_array.shape) >= 2: + dicom_to_nifti_multiframe(ds, nii_path) + + return True + except Exception as e: + logger.on_debug("Multi-Frame DICOM did not work:", e) convert_dicom.dicom_array_to_nifti(dicom_out_path, nii_path, True) else: # func_timeout(10, dicom2nifti.dicom_series_to_nifti, (dicom_out_path, nii_path, True)) @@ -112,8 +166,9 @@ def _convert_to_nifti(dicom_out_path, nii_path): logger.print("Save ", nii_path, Log_Type.SAVE) except dicom2nifti.exceptions.ConversionValidationError as e: if e.args[0] in ["NON_IMAGING_DICOM_FILES"]: + s = f"dicom_array_to_nifti len={len(dicom_out_path)}" if isinstance(dicom_out_path, list) else "dicom_series_to_nifti" Path(str(nii_path).replace(".nii.gz", ".json")).unlink(missing_ok=True) - logger.on_debug(f"Not exportable '{Path(nii_path).name}':", e.args[0]) + logger.on_debug(f"Not exportable '{Path(nii_path).name}':", e.args[0], s) return False for key, reason in [ ("validate_orthogonal", "NON_CUBICAL_IMAGE/GANTRY_TILT"), @@ -311,7 +366,7 @@ def _read_dicom_files(dicom_out_path: Path) -> tuple[dict[str, list[FileDataset] path = Path(_paths) if path.is_file(): try: - dcm_data = pydicom.dcmread(path, defer_size="1 KB", force=True) + dcm_data = pydicom.dcmread(path, defer_size="1 KB", force=True) # , stop_before_pixels=True try: typ = ( str(dcm_data.get_item((0x0008, 0x0008)).value) @@ -324,7 +379,6 @@ def _read_dicom_files(dicom_out_path: Path) -> tuple[dict[str, list[FileDataset] except Exception: typ = "" key1 = str(dcm_data.SeriesInstanceUID) - key = f"{key1}_{typ}" if not hasattr(dcm_data, "ImageOrientationPatient"): key += "_" + dcm_data.get("SOPInstanceUID", 0) @@ -523,6 +577,13 @@ def process_series(key, files, parts): if __name__ == "__main__": + extract_dicom_folder( + Path("/media/robert/NRAD/DSA_Data/DICOMS_DSA_all/0001018804/"), + Path("/media/data/robert/test/dicom2nii", f"dataset-{'dsa'}"), + False, + 0, + ) + sys.exit() # s = "/home/robert/Downloads/bein/dataset-oberschenkel/rawdata/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497/mr/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497_sequ-406_mr.nii.gz" # nii2 = NII.load(s, False) # print(nii2.affine, nii2.orientation) From 584efea25db34863e17751f0c91d8f57fac423d0 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 09:30:40 -0400 Subject: [PATCH 15/28] prevent error when cuda is not installed. --- TPTBox/segmentation/nnUnet_utils/predictor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index a917b54..6818535 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -224,7 +224,8 @@ def mapp(d: dict): self.network.load_state_dict(params) # type: ignore else: self.network._orig_mod.load_state_dict(params) - self.network.cuda() # type: ignore + if self.device.type == "cuda": + self.network.cuda() # type: ignore self.network.eval() # type: ignore self.loaded_networks.append(self.network) # print(type(self.loaded_networks[0])) From 11e4e8a6f254c1aa5a33c4e9e8189dc99111a8de Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 09:32:42 -0400 Subject: [PATCH 16/28] remove total from name --- TPTBox/core/dicom/dicom_extract.py | 13 ++++----- TPTBox/segmentation/TotalVibeSeg/__init__.py | 3 -- .../auto_download.py | 0 .../inference_nnunet.py | 4 +-- .../totalvibeseg.py => VibeSeg/vibeseg.py} | 29 +++++++++---------- TPTBox/segmentation/__init__.py | 3 +- .../IVD_transfer/transfare_spine_seg.py | 2 +- 7 files changed, 24 insertions(+), 30 deletions(-) delete mode 100644 TPTBox/segmentation/TotalVibeSeg/__init__.py rename TPTBox/segmentation/{TotalVibeSeg => VibeSeg}/auto_download.py (100%) rename TPTBox/segmentation/{TotalVibeSeg => VibeSeg}/inference_nnunet.py (98%) rename TPTBox/segmentation/{TotalVibeSeg/totalvibeseg.py => VibeSeg/vibeseg.py} (89%) diff --git a/TPTBox/core/dicom/dicom_extract.py b/TPTBox/core/dicom/dicom_extract.py index 0a0075c..124a2a5 100644 --- a/TPTBox/core/dicom/dicom_extract.py +++ b/TPTBox/core/dicom/dicom_extract.py @@ -94,7 +94,7 @@ def dicom_to_nifti_multiframe(ds, nii_path): # Pixel spacing (mm) if hasattr(ds, "PixelSpacing"): - dy, dx = [float(v) for v in ds.PixelSpacing] + dy, dx = (float(v) for v in ds.PixelSpacing) # Image orientation (row and column direction cosines) orientation = [float(v) for v in ds.ImageOrientationPatient] row_cosines = np.array(orientation[0:3]) @@ -119,7 +119,7 @@ def dicom_to_nifti_multiframe(ds, nii_path): affine[0:3, 2] = slice_cosines * dz affine[0:3, 3] = origin else: - dy, dx = [float(v) for v in ds.ImagerPixelSpacing] + dy, dx = (float(v) for v in ds.ImagerPixelSpacing) # Einfaches affine (nur 2D + Zeit, keine Lage im Patientenraum) affine = np.eye(4) affine[0, 0] = -dx @@ -577,12 +577,9 @@ def process_series(key, files, parts): if __name__ == "__main__": - extract_dicom_folder( - Path("/media/robert/NRAD/DSA_Data/DICOMS_DSA_all/0001018804/"), - Path("/media/data/robert/test/dicom2nii", f"dataset-{'dsa'}"), - False, - 0, - ) + for p in Path("E:/DSA_Data/DICOMS_DSA_all/").iterdir(): + extract_dicom_folder(p, Path("D:/data/DSA", "dataset-DSA"), False, False) + sys.exit() # s = "/home/robert/Downloads/bein/dataset-oberschenkel/rawdata/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497/mr/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497_sequ-406_mr.nii.gz" # nii2 = NII.load(s, False) diff --git a/TPTBox/segmentation/TotalVibeSeg/__init__.py b/TPTBox/segmentation/TotalVibeSeg/__init__.py deleted file mode 100644 index aa1bfe1..0000000 --- a/TPTBox/segmentation/TotalVibeSeg/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import annotations - -from TPTBox.segmentation.TotalVibeSeg.totalvibeseg import extract_vertebra_bodies_from_totalVibe, run_totalvibeseg, total_vibe_map diff --git a/TPTBox/segmentation/TotalVibeSeg/auto_download.py b/TPTBox/segmentation/VibeSeg/auto_download.py similarity index 100% rename from TPTBox/segmentation/TotalVibeSeg/auto_download.py rename to TPTBox/segmentation/VibeSeg/auto_download.py diff --git a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py similarity index 98% rename from TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py rename to TPTBox/segmentation/VibeSeg/inference_nnunet.py index 8681524..a062cb0 100644 --- a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -11,7 +11,7 @@ import torch from TPTBox import NII, Image_Reference, Log_Type, Print_Logger, to_nii -from TPTBox.segmentation.TotalVibeSeg.auto_download import download_weights +from TPTBox.segmentation.VibeSeg.auto_download import download_weights logger = Print_Logger() out_base = Path(__file__).parent.parent / "nnUNet/" @@ -187,7 +187,7 @@ def run_inference_on_file( idx_models = [80, 87, 86, 85] -def run_total_seg( +def run_VibeSeg( img: Path | str | list[Path] | list[NII], out_path: Path, override=False, diff --git a/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py b/TPTBox/segmentation/VibeSeg/vibeseg.py similarity index 89% rename from TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py rename to TPTBox/segmentation/VibeSeg/vibeseg.py index 2a2d5ee..51a26e3 100644 --- a/TPTBox/segmentation/TotalVibeSeg/totalvibeseg.py +++ b/TPTBox/segmentation/VibeSeg/vibeseg.py @@ -4,9 +4,9 @@ from typing import Literal from TPTBox import Image_Reference, to_nii -from TPTBox.segmentation.TotalVibeSeg.inference_nnunet import run_inference_on_file +from TPTBox.segmentation.VibeSeg.inference_nnunet import run_inference_on_file -total_vibe_map = { +VibeSeg_map = { 1: "spleen", 2: "kidney_right", 3: "kidney_left", @@ -82,7 +82,7 @@ } -def run_totalvibeseg( +def run_vibeseg( i: Image_Reference, out_seg: str | Path, override=False, @@ -126,22 +126,22 @@ def run_nnunet( ) -def extract_vertebra_bodies_from_totalVibe( - nii_total: Image_Reference, +def extract_vertebra_bodies_from_VibeSeg( + nii_vibeSeg: Image_Reference, num_thoracic_verts: int = 12, num_lumbar_verts: int = 5, out_path: str | Path | None = None, out_path_poi: str | Path | None = None, ): """ - Extracts and labels vertebra bodies from a totalVibe segmentation NIfTI file. + Extracts and labels vertebra bodies from a VibeSeg segmentation NIfTI file. This function processes a segmentation mask containing vertebrae and intervertebral discs (IVDs). It separates individual vertebra bodies by eroding and splitting the mask at IVD regions, labels the vertebrae from bottom to top (lumbar and thoracic), and optionally saves the labeled mask and point-of-interest (POI) data. Args: - nii_total (Image_Reference): Path or reference to the NIfTI file containing the totalVibe segmentation mask. + nii_vibeSeg (Image_Reference): Path or reference to the NIfTI file containing the VibeSeg segmentation mask. num_thoracic_verts (int, optional): Number of thoracic vertebrae to include. Defaults to 12. num_lumbar_verts (int, optional): Number of lumbar vertebrae to include. Defaults to 5. out_path (str | Path | None, optional): Path to save the processed mask data. If None, no files are saved. Defaults to None. @@ -159,13 +159,13 @@ def extract_vertebra_bodies_from_totalVibe( - Mask file: `` - POI file: `` with `_poi.json` suffix recommended. Example: - >>> nii_total = "/path/to/vibe_segmentation.nii.gz" - >>> labeled_mask, centroids = extract_vertebra_bodies_from_totalVibe(nii_total, out_path="output_mask.nii.gz") + >>> nii_vibeSeg = "/path/to/vibe_segmentation.nii.gz" + >>> labeled_mask, centroids = extract_vertebra_bodies_from_nii_vibeSeg(nii_vibeSeg, out_path="output_mask.nii.gz") """ from TPTBox import Vertebra_Instance, calc_centroids - # Load the totalVibe segmentation - nii = to_nii(nii_total, seg=True) + # Load the nii_vibeSeg segmentation + nii = to_nii(nii_vibeSeg, seg=True) vertebrae = nii.extract_label(69) ivds = nii.extract_label(68) @@ -211,9 +211,8 @@ def map_to_label(index): if __name__ == "__main__": from TPTBox import BIDS_FILE - from TPTBox.segmentation import run_totalvibeseg + from TPTBox.segmentation import run_vibeseg - # run_totalvibeseg # You can also use a string/Path if you want to set the path yourself. dataset = "/media/data/robert/datasets/dicom_example/dataset-VR-DICOM2/" in_file = BIDS_FILE( @@ -224,6 +223,6 @@ def map_to_label(index): "nii.gz", "msk", parent="derivative", - info={"seg": "TotalVibeSegmentator", "mod": in_file.bids_format}, + info={"seg": "VibeSeg", "mod": in_file.bids_format}, ) - run_totalvibeseg(in_file, out_file) + run_vibeseg(in_file, out_file) diff --git a/TPTBox/segmentation/__init__.py b/TPTBox/segmentation/__init__.py index a23e6b5..aa94afb 100644 --- a/TPTBox/segmentation/__init__.py +++ b/TPTBox/segmentation/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations from TPTBox.segmentation.spineps import run_spineps_all, run_spineps_single -from TPTBox.segmentation.TotalVibeSeg import run_totalvibeseg +from TPTBox.segmentation.VibeSeg import run_vibeseg +from TPTBox.segmentation.VibeSeg import run_vibeseg as run_totalvibeseg # TODO deprecate diff --git a/examples/registration/IVD_transfer/transfare_spine_seg.py b/examples/registration/IVD_transfer/transfare_spine_seg.py index 004c239..9b3a404 100644 --- a/examples/registration/IVD_transfer/transfare_spine_seg.py +++ b/examples/registration/IVD_transfer/transfare_spine_seg.py @@ -277,7 +277,7 @@ def _remove_LWS6_and_5(ref): def _get_template(vert: NII): - from TPTBox.segmentation.TotalVibeSeg.auto_download import _download + from TPTBox.segmentation.VibeSeg.auto_download import _download tmp = Path(os.path.join(gettempdir(), "spine-templates")) tmp.mkdir(exist_ok=True) From 01a36db364562bc29eaec3dc182a22889fa08fcd Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 10:40:58 -0400 Subject: [PATCH 17/28] remove oar segmentation. Use CATS instead --- .../segmentation/oar_segmentator/__init__.py | 0 .../oar_segmentator/map_to_binary.py | 451 ------------------ TPTBox/segmentation/oar_segmentator/run.py | 121 ----- 3 files changed, 572 deletions(-) delete mode 100644 TPTBox/segmentation/oar_segmentator/__init__.py delete mode 100644 TPTBox/segmentation/oar_segmentator/map_to_binary.py delete mode 100644 TPTBox/segmentation/oar_segmentator/run.py diff --git a/TPTBox/segmentation/oar_segmentator/__init__.py b/TPTBox/segmentation/oar_segmentator/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/TPTBox/segmentation/oar_segmentator/map_to_binary.py b/TPTBox/segmentation/oar_segmentator/map_to_binary.py deleted file mode 100644 index 4a566d9..0000000 --- a/TPTBox/segmentation/oar_segmentator/map_to_binary.py +++ /dev/null @@ -1,451 +0,0 @@ -from __future__ import annotations - -class_map = { - 1: "spleen", - 2: "kidney_right", - 3: "kidney_left", - 4: "gallbladder", - 5: "liver", - 6: "stomach", - 7: "aorta", - 8: "inferior_vena_cava", - 9: "portal_vein_and_splenic_vein", - 10: "pancreas", - 11: "adrenal_gland_right", - 12: "adrenal_gland_left", - 13: "lung_upper_lobe_left", - 14: "lung_lower_lobe_left", - 15: "lung_upper_lobe_right", - 16: "lung_middle_lobe_right", - 17: "lung_lower_lobe_right", - 18: "vertebrae_L5", - 19: "vertebrae_L4", - 20: "vertebrae_L3", - 21: "vertebrae_L2", - 22: "vertebrae_L1", - 23: "vertebrae_T12", - 24: "vertebrae_T11", - 25: "vertebrae_T10", - 26: "vertebrae_T9", - 27: "vertebrae_T8", - 28: "vertebrae_T7", - 29: "vertebrae_T6", - 30: "vertebrae_T5", - 31: "vertebrae_T4", - 32: "vertebrae_T3", - 33: "vertebrae_T2", - 34: "vertebrae_T1", - 35: "vertebrae_C7", - 36: "vertebrae_C6", - 37: "vertebrae_C5", - 38: "vertebrae_C4", - 39: "vertebrae_C3", - 40: "vertebrae_C2", - 41: "vertebrae_C1", - 42: "esophagus", - 43: "trachea", - 44: "heart_myocardium", - 45: "heart_atrium_left", - 46: "heart_ventricle_left", - 47: "heart_atrium_right", - 48: "heart_ventricle_right", - 49: "pulmonary_artery", - 50: "brain", - 51: "iliac_artery_left", - 52: "iliac_artery_right", - 53: "iliac_vena_left", - 54: "iliac_vena_right", - 55: "small_bowel", - 56: "duodenum", - 57: "colon", - 58: "rib_left_1", - 59: "rib_left_2", - 60: "rib_left_3", - 61: "rib_left_4", - 62: "rib_left_5", - 63: "rib_left_6", - 64: "rib_left_7", - 65: "rib_left_8", - 66: "rib_left_9", - 67: "rib_left_10", - 68: "rib_left_11", - 69: "rib_left_12", - 70: "rib_right_1", - 71: "rib_right_2", - 72: "rib_right_3", - 73: "rib_right_4", - 74: "rib_right_5", - 75: "rib_right_6", - 76: "rib_right_7", - 77: "rib_right_8", - 78: "rib_right_9", - 79: "rib_right_10", - 80: "rib_right_11", - 81: "rib_right_12", - 82: "humerus_left", - 83: "humerus_right", - 84: "scapula_left", - 85: "scapula_right", - 86: "clavicula_left", - 87: "clavicula_right", - 88: "femur_left", - 89: "femur_right", - 90: "hip_left", - 91: "hip_right", - 92: "sacrum", - 93: "face", - 94: "gluteus_maximus_left", - 95: "gluteus_maximus_right", - 96: "gluteus_medius_left", - 97: "gluteus_medius_right", - 98: "gluteus_minimus_left", - 99: "gluteus_minimus_right", - 100: "autochthon_left", - 101: "autochthon_right", - 102: "iliopsoas_left", - 103: "iliopsoas_right", - 104: "urinary_bladder", - 105: "sternum", - 106: "thyroid gland", - 107: "right adrenal gland", - 108: "left adrenal gland", - 109: "right psoas major", - 110: "left psoas major", - 111: "right rectus abdominis", - 112: "left rectus abdominis", - 113: "brainstem", - 114: "spinal canal", - 115: "left parotid gland", - 116: "right parotid gland", - 117: "left submandibular gland", - 118: "right submandibular gland", - 119: "larynx", - 120: "sigmoid", - 121: "rectum", - 122: "prostate", - 123: "seminal vasicle", - 124: "left mammary gland", - 125: "right mammary gland", - 126: "white matter", - 127: "gray matter", - 128: "csf", - 129: "head bone", - 130: "scalp", - 131: "eye balls", - 132: "compact bone", - 133: "spongy bone", - 134: "blood", - 135: "head muscles", - 136: "OAR_A_Carotid_L", - 137: "OAR_A_Carotid_R", - 138: "OAR_Arytenoid", - 139: "OAR_Bone_Mandible", - 140: "OAR_Brainstem", - 141: "OAR_BuccalMucosa", - 142: "OAR_Cavity_Oral", - 143: "OAR_Cochlea_L", - 144: "OAR_Cochlea_R", - 145: "OAR_Cricopharyngeus", - 146: "OAR_Esophagus_S", - 147: "OAR_Eye_AL", - 148: "OAR_Eye_AR", - 149: "OAR_Eye_PL", - 150: "OAR_Eye_PR", - 151: "OAR_Glnd_Lacrimal_L", - 152: "OAR_Glnd_Lacrimal_R", - 153: "OAR_Glnd_Submand_L", - 154: "OAR_Glnd_Submand_R", - 155: "OAR_Glnd_Thyroid", - 156: "OAR_Glottis", - 157: "OAR_Larynx_SG", - 158: "OAR_Lips", - 159: "OAR_OpticChiasm", - 160: "OAR_OpticNrv_L", - 161: "OAR_OpticNrv_R", - 162: "OAR_Parotid_L", - 163: "OAR_Parotid_R", - 164: "OAR_Pituitary", - 165: "OAR_SpinalCord", - 166: "subcutaneous_tissue", - 167: "muscle", - 168: "abdominal_cavity", - 169: "thoracic_cavity", - 170: "bones", - 171: "glands", - 172: "pericardium", - 173: "breast_implant", - 174: "mediastinum", - 175: "brain_head", - 176: "spinal_cord", -} - -class_map_9_parts = { - # 17 classes part 251 - "class_map_part_organs": { - 1: "spleen", - 2: "kidney_right", - 3: "kidney_left", - 4: "gallbladder", - 5: "liver", - 6: "stomach", - 7: "aorta", - 8: "inferior_vena_cava", - 9: "portal_vein_and_splenic_vein", - 10: "pancreas", - 11: "adrenal_gland_right", - 12: "adrenal_gland_left", - 13: "lung_upper_lobe_left", - 14: "lung_lower_lobe_left", - 15: "lung_upper_lobe_right", - 16: "lung_middle_lobe_right", - 17: "lung_lower_lobe_right", - }, - # 24 classes part 252 - "class_map_part_vertebrae": { - 1: "vertebrae_L5", - 2: "vertebrae_L4", - 3: "vertebrae_L3", - 4: "vertebrae_L2", - 5: "vertebrae_L1", - 6: "vertebrae_T12", - 7: "vertebrae_T11", - 8: "vertebrae_T10", - 9: "vertebrae_T9", - 10: "vertebrae_T8", - 11: "vertebrae_T7", - 12: "vertebrae_T6", - 13: "vertebrae_T5", - 14: "vertebrae_T4", - 15: "vertebrae_T3", - 16: "vertebrae_T2", - 17: "vertebrae_T1", - 18: "vertebrae_C7", - 19: "vertebrae_C6", - 20: "vertebrae_C5", - 21: "vertebrae_C4", - 22: "vertebrae_C3", - 23: "vertebrae_C2", - 24: "vertebrae_C1", - }, - # 18 part 253 - "class_map_part_cardiac": { - 1: "esophagus", - 2: "trachea", - 3: "heart_myocardium", - 4: "heart_atrium_left", - 5: "heart_ventricle_left", - 6: "heart_atrium_right", - 7: "heart_ventricle_right", - 8: "pulmonary_artery", - 9: "brain", - 10: "iliac_artery_left", - 11: "iliac_artery_right", - 12: "iliac_vena_left", - 13: "iliac_vena_right", - 14: "small_bowel", - 15: "duodenum", - 16: "colon", - 17: "urinary_bladder", - 18: "face", - }, - # 21 part 254 - "class_map_part_muscles": { - 1: "humerus_left", - 2: "humerus_right", - 3: "scapula_left", - 4: "scapula_right", - 5: "clavicula_left", - 6: "clavicula_right", - 7: "femur_left", - 8: "femur_right", - 9: "hip_left", - 10: "hip_right", - 11: "sacrum", - 12: "gluteus_maximus_left", - 13: "gluteus_maximus_right", - 14: "gluteus_medius_left", - 15: "gluteus_medius_right", - 16: "gluteus_minimus_left", - 17: "gluteus_minimus_right", - 18: "autochthon_left", - 19: "autochthon_right", - 20: "iliopsoas_left", - 21: "iliopsoas_right", - }, - # 24 classes part 255 - # 12. ribs start from vertebrae T12 - # Small subset of population (roughly 8%) have 13. rib below 12. rib - # (would start from L1 then) - # -> this has label rib_12 - # Even smaller subset (roughly 1%) has extra rib above 1. rib ("Halsrippe") - # (the extra rib would start from C7) - # -> this has label rib_1 - # - # Quite often only 11 ribs (12. ribs probably so small that not found). Those - # cases often wrongly segmented. - "class_map_part_ribs": { - 1: "rib_left_1", - 2: "rib_left_2", - 3: "rib_left_3", - 4: "rib_left_4", - 5: "rib_left_5", - 6: "rib_left_6", - 7: "rib_left_7", - 8: "rib_left_8", - 9: "rib_left_9", - 10: "rib_left_10", - 11: "rib_left_11", - 12: "rib_left_12", - 13: "rib_right_1", - 14: "rib_right_2", - 15: "rib_right_3", - 16: "rib_right_4", - 17: "rib_right_5", - 18: "rib_right_6", - 19: "rib_right_7", - 20: "rib_right_8", - 21: "rib_right_9", - 22: "rib_right_10", - 23: "rib_right_11", - 24: "rib_right_12", - }, - # 21 classes organs at risk in house and visceral - # part 256 - "class_map_part_oarrad": { - 1: "sternum", - 2: "thyroid gland", - 3: "right adrenal gland", - 4: "left adrenal gland", - 5: "right psoas major", - 6: "left psoas major", - 7: "right rectus abdominis", - 8: "left rectus abdominis", - 9: "brainstem", - 10: "spinal canal", - 11: "left parotid gland", - 12: "right parotid gland", - 13: "left submandibular gland", - 14: "right submandibular gland", - 15: "larynx", - 16: "sigmoid", - 17: "rectum", - 18: "prostate", - 19: "seminal vasicle", - 20: "left mammary gland", - 21: "right mammary gland", - }, - ##10 classes head koens registration - ## part 257 - "class_map_part_head": { - 1: "white matter", - 2: "gray matter", - 3: "csf", - 4: "head bone", - 5: "scalp", - 6: "eye balls", - 7: "compact bone", - 8: "spongy bone", - 9: "blood", - 10: "head muscles", - }, - ## 30 lasses Head and neck structures from HanSeg - ## part 258 - "class_map_part_headneck": { - 1: "OAR_A_Carotid_L", - 2: "OAR_A_Carotid_R", - 3: "OAR_Arytenoid", - 4: "OAR_Bone_Mandible", - 5: "OAR_Brainstem", - 6: "OAR_BuccalMucosa", - 7: "OAR_Cavity_Oral", - 8: "OAR_Cochlea_L", - 9: "OAR_Cochlea_R", - 10: "OAR_Cricopharyngeus", - 11: "OAR_Esophagus_S", - 12: "OAR_Eye_AL", - 13: "OAR_Eye_AR", - 14: "OAR_Eye_PL", - 15: "OAR_Eye_PR", - 16: "OAR_Glnd_Lacrimal_L", - 17: "OAR_Glnd_Lacrimal_R", - 18: "OAR_Glnd_Submand_L", - 19: "OAR_Glnd_Submand_R", - 20: "OAR_Glnd_Thyroid", - 21: "OAR_Glottis", - 22: "OAR_Larynx_SG", - 23: "OAR_Lips", - 24: "OAR_OpticChiasm", - 25: "OAR_OpticNrv_L", - 26: "OAR_OpticNrv_R", - 27: "OAR_Parotid_L", - 28: "OAR_Parotid_R", - 29: "OAR_Pituitary", - 30: "OAR_SpinalCord", - }, - ## 11 claases for body parts including subcutaneous tissue - ## part 259 - "class_map_part_bodyregions": { - 1: "subcutaneous_tissue", - 2: "muscle", - 3: "abdominal_cavity", - 4: "thoracic_cavity", - 5: "bones", - 6: "glands", - 7: "pericardium", - 8: "breast_implant", - 9: "mediastinum", - 10: "brain", - 11: "spinal_cord", - }, -} - -# map_taskid_to_partname = { -# 251: "class_map_part_organs", -# 252: "class_map_part_vertebrae", -# 253: "class_map_part_cardiac", -# 254: "class_map_part_muscles", -# 255: "class_map_part_ribs" -# } - -map_taskid_to_partname = { - 251: "class_map_part_organs", - 252: "class_map_part_vertebrae", - 253: "class_map_part_cardiac", - 254: "class_map_part_muscles", - 255: "class_map_part_ribs", - 256: "class_map_part_oarrad", - 257: "class_map_part_head", - 258: "class_map_part_headneck", - 259: "class_map_part_bodyregions", -} -except_labels_combine = [ - "OAR_Brainstem", - "brain", # place for gray-white matter - "bone", # place for vertebrae - "OAR_Parotid_L", # included in visceral - "OAR_Parotid_R", # included in visceral - "OAR_Glnd_Submand_L", # included in visceral - "OAR_Glnd_Submand_R", # included in visceral - "OAR_Glnd_Thyroid", # included in visceral - "OAR_Larynx_SG", # included in visceral - "right adrenal gland", # included in totalseg - "left adrenal gland", # included in totalseg - "OAR_Esophagus_S", # included in totalseg - "face", # for visualization - "glands", # from saros - "head bone", # overlaps with hanseg structures - "scalp", # cover all head structures - "eye balls", # repeated in hanseg - "compact bone", # overlaps with hanseg structures - "spongy bone", # overlaps with hanseg structures - "blood", # overlaps with hanseg structures - "head muscles", # overlaps with hanseg structures - "muscle", - "abdominal_cavity", - "bones", # overlaps for vertebrae - "glands", - "pericardium", - "breast_implant", - "mediastinum", - "spinal_cord", -] diff --git a/TPTBox/segmentation/oar_segmentator/run.py b/TPTBox/segmentation/oar_segmentator/run.py deleted file mode 100644 index a8efc63..0000000 --- a/TPTBox/segmentation/oar_segmentator/run.py +++ /dev/null @@ -1,121 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path - -import GPUtil -from tqdm import tqdm - -from TPTBox import BIDS_FILE, NII, POI, to_nii -from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference -from TPTBox.segmentation.oar_segmentator.map_to_binary import class_map, class_map_9_parts, except_labels_combine, map_taskid_to_partname - -class_map_inv = {v: k for k, v in class_map.items()} - - -def save_resampled_segmentation(seg_nii: NII, in_file: BIDS_FILE, parent, org: NII | POI, idx): - """Helper function to resample and save NIfTI file.""" - out_path = in_file.get_changed_path("nii.gz", "msk", parent=parent, info={"seg": f"oar-{idx}"}, non_strict_mode=True) - seg_nii.resample_from_to(org, verbose=False, mode="nearest").save(out_path) - - -def run_oar_segmentor( - ct_path: Path | str | BIDS_FILE, - dataset: Path | str | None = None, - oar_path="/home/fercho/code/oar_segmentator/models/nnunet/results/nnUNet/3d_fullres/", - parent="derivatives", - gpu=None, - override=False, -): - ## Hard coded info ## - zoom = 1.5 - orientation = ("R", "A", "S") - #### - if isinstance(ct_path, BIDS_FILE): - in_file = ct_path - else: - if dataset is None: - dataset = Path(ct_path).parent - in_file = BIDS_FILE(ct_path, dataset) - out_path_combined = in_file.get_changed_path("nii.gz", "msk", parent=parent, info={"seg": "oar-combined"}, non_strict_mode=True) - if out_path_combined.exists() and not override: - print("skip", out_path_combined.name, " ", end="\r") - return - org = to_nii(in_file) - print("resample ") - input_nii = org.rescale((zoom, zoom, zoom), mode="nearest").reorient(orientation) - org = (org.shape, org.affine, org.zoom) - segs: dict[int, NII] = {} - futures = [] - # Create ThreadPoolExecutor for parallel saving - print("start") - with ThreadPoolExecutor(max_workers=4) as executor: - for idx in tqdm(range(251, 260), desc="Predict oar segmentation"): - # Suppress stdout and stderr for run_inference - nnunet_path = next(next(iter(Path(oar_path).glob(f"*{idx}*"))).glob("*__nnUNetPlans*")) - nnunet = load_inf_model(nnunet_path, allow_non_final=True, use_folds=(0,), gpu=gpu) - seg_nii, _, _ = run_inference(input_nii, nnunet, logits=False) - segs[idx] = seg_nii - # Submit the save task to the thread pool - futures.append(executor.submit(save_resampled_segmentation, seg_nii, in_file, parent, org, idx)) - # Wait for all save tasks to complete - for future in as_completed(futures): - future.result() # Ensure any exceptions in threads are raised - seg_combined = segs[251] * 0 - for tid in range(251, 260): - seg = segs[tid] - for jdx, class_name in class_map_9_parts[map_taskid_to_partname[tid]].items(): - if any(class_name in s for s in except_labels_combine): - continue - seg_combined[seg == jdx] = class_map_inv[class_name] - seg_combined.resample_from_to(org, verbose=False, mode="nearest").save(out_path_combined) - - -def check_gpu_memory(gpu_id, threshold=50): - """Check the GPU memory utilization and return True if usage exceeds threshold.""" - gpus = GPUtil.getGPUs() - for gpu in gpus: - if gpu.id == gpu_id: - return gpu.memoryUtil * 100 > threshold - return False - - -def run_oar_segmentor_in_parallel(dataset, parents: Sequence[str] = ("rawdata",), gpu_id=3, threshold=50, max_workers=16, override=False): - """Run the OAR segmentation in parallel and pause when GPU memory usage exceeds the threshold.""" - from TPTBox import BIDS_Global_info - - bgi = BIDS_Global_info([dataset], parents=parents) - - futures = [] - - # ThreadPoolExecutor for parallel execution - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for _name, subject in bgi.enumerate_subjects(): - q = subject.new_query(flatten=True) - q.filter_filetype("nii.gz") - q.filter_format("ct") - - for i in q.loop_list(): - # Check GPU memory usage and pause if above threshold - while check_gpu_memory(gpu_id, threshold): - print(f"GPU memory usage exceeded {threshold}%. Pausing submission...") - time.sleep(10) # Pause for 10 seconds before checking again - - # Submit run_oar_segmentor task to the executor - futures.append(executor.submit(run_oar_segmentor, i, gpu=gpu_id, override=override)) - - # Wait for all tasks to complete - for future in as_completed(futures): - try: - future.result() # This will raise any exceptions encountered - except Exception as e: - print(f"Error in execution: {e}") - - -if __name__ == "__main__": - # Example usage - bgi = "/DATA/NAS/datasets_processed/CT_spine/dataset-shockroom-without-fx/" - - run_oar_segmentor_in_parallel(bgi, ("rawdata_fixed",), gpu_id=0, threshold=50, max_workers=16, override=False) From 2976ba89698be60a6a6a68955017150311b80e4b Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 11:55:26 -0400 Subject: [PATCH 18/28] add __all__ --- TPTBox/registration/__init__.py | 6 ++++ .../ridged_intensity/affine_deepali.py | 34 +------------------ .../registration/ridged_intensity/register.py | 3 ++ TPTBox/segmentation/__init__.py | 2 +- 4 files changed, 11 insertions(+), 34 deletions(-) diff --git a/TPTBox/registration/__init__.py b/TPTBox/registration/__init__.py index cacd365..b74375d 100755 --- a/TPTBox/registration/__init__.py +++ b/TPTBox/registration/__init__.py @@ -16,3 +16,9 @@ except Exception: pass +__all__ = [ + "General_Registration", + "Point_Registration", + "ridged_points_from_poi", + "ridged_points_from_subreg_vert", +] diff --git a/TPTBox/registration/ridged_intensity/affine_deepali.py b/TPTBox/registration/ridged_intensity/affine_deepali.py index 1dad5bb..9a4f9a9 100644 --- a/TPTBox/registration/ridged_intensity/affine_deepali.py +++ b/TPTBox/registration/ridged_intensity/affine_deepali.py @@ -254,7 +254,7 @@ def forward( # Pairwise distances (N_wrong, N_gt); differentiable d = torch.cdist(subsample_coords(wrong_coords, 5000), gt_coords) - min_dists = d.min(dim=1).values # (N_wrong,) + min_dists = d.min(dim=1).to_numpy() # (N_wrong,) per_class_losses.append(min_dists.mean()) @@ -266,38 +266,6 @@ def forward( return torch.stack(per_class_losses).mean() -class LABEL_LOSS(PairwiseImageLoss): - def __init__( - self, - max_v=1, - *args, - **kwargs, - ) -> None: - self.max_v = max_v - super().__init__(*args, **kwargs) - - def forward( - self, - source: torch.Tensor, - target: torch.Tensor, - mask: torch.Tensor | None = None, # noqa: ARG002 - ) -> torch.Tensor: # noqa: ARG002 - loss = torch.zeros(1, device=source.device) - target = (target * self.max_v).round(decimals=0) - source = (source * self.max_v).round(decimals=0) - u = torch.unique(target) - for i in u: - if i == 0: - continue - com_fixed = (target == i) ^ (source == i) - l_com = com_fixed.sum() / ((source == i).sum() + (target == i).sum()) - l_com = torch.nan_to_num(l_com, nan=0) - print(i, l_com) - loss += l_com - - return loss / len(u) - - class Rigid_Registration_with_Tether(General_Registration): def __init__( self, diff --git a/TPTBox/registration/ridged_intensity/register.py b/TPTBox/registration/ridged_intensity/register.py index 3f1a033..0d8ed0c 100755 --- a/TPTBox/registration/ridged_intensity/register.py +++ b/TPTBox/registration/ridged_intensity/register.py @@ -20,6 +20,9 @@ raise from TPTBox import AX_CODES, NII +""" +Wrapper functions for different registration methods with ants and nipy. +""" Similarity_Measures = Literal["slr", "mi", "pmi", "dpmi", "cc", "cr", "crl1"] Affine_Transforms = Literal["affine", "affine2d", "similarity", "similarity2d", "rigid", "rigid2d"] diff --git a/TPTBox/segmentation/__init__.py b/TPTBox/segmentation/__init__.py index aa94afb..ced3f44 100644 --- a/TPTBox/segmentation/__init__.py +++ b/TPTBox/segmentation/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations from TPTBox.segmentation.spineps import run_spineps_all, run_spineps_single -from TPTBox.segmentation.VibeSeg import run_vibeseg +from TPTBox.segmentation.VibeSeg import extract_vertebra_bodies_from_VibeSeg, run_inference_on_file, run_nnunet, run_vibeseg from TPTBox.segmentation.VibeSeg import run_vibeseg as run_totalvibeseg # TODO deprecate From 3332ea6fc3cc12a1a07df4d99e0aebd66d263324 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 14:05:34 -0400 Subject: [PATCH 19/28] x --- TPTBox/segmentation/VibeSeg/inference_nnunet.py | 2 +- pyproject.toml | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index e9cd759..4f1295a 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -72,7 +72,7 @@ def run_inference_on_file( if out_file is not None and Path(out_file).exists() and not override: return out_file, None - from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference # noqa: PLC0415 + from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference if isinstance(idx, int): download_weights(idx, model_path) diff --git a/pyproject.toml b/pyproject.toml index 6fc9f79..d22ed2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ select = [ ignore = [ "C901", + "RUF059", "RUF100", "F401", "BLE001", @@ -158,13 +159,11 @@ ignore = [ "B905", # strict= in zip "UP007", # Union and "|" python 3.9 "PLC0415", # import-outside-top-level - "RUF059", -] + ] # Allow fix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] unfixable = [] -ignore-init-module-imports = true extend-safe-fixes = ["RUF015", "C419", "C408", "B006"] #unnecessary-iterable-allocation-for-first-element = true From c77ffb22559d4c55b9ad010d3babc8ed1dc24f2d Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 17 Oct 2025 14:14:22 -0400 Subject: [PATCH 20/28] update tests --- unit_tests/test_poi_global.py | 3 --- unit_tests/test_reg_seg.py | 45 ----------------------------------- 2 files changed, 48 deletions(-) delete mode 100755 unit_tests/test_reg_seg.py diff --git a/unit_tests/test_poi_global.py b/unit_tests/test_poi_global.py index 44c926f..2874fa1 100755 --- a/unit_tests/test_poi_global.py +++ b/unit_tests/test_poi_global.py @@ -22,9 +22,6 @@ def test_glob_by_definition(self): poi.rescale_((3, 2, 1)) glob_poi.to_other_poi(poi) - def test_not_implemented(self): - self.assertRaises(NotImplementedError, POI_Global, None) - def test_is_global(self): poi = get_poi() poi.orientation = ("L", "A", "S") diff --git a/unit_tests/test_reg_seg.py b/unit_tests/test_reg_seg.py deleted file mode 100755 index e1efdff..0000000 --- a/unit_tests/test_reg_seg.py +++ /dev/null @@ -1,45 +0,0 @@ -# Call 'python -m unittest' on this folder -# coverage run -m unittest -# coverage report -# coverage html -from __future__ import annotations - -import os -import random -import sys -import unittest -from pathlib import Path - -sys.path.append(str(Path(__file__).resolve().parents[2])) -import nibabel as nib - -from TPTBox import to_nii - -test_data = [ - "BIDS/test/test_data/sub-fxclass0001_seg-subreg_msk.nii.gz", - "BIDS/test/test_data/sub-fxclass0001_seg-vert_msk.nii.gz", - "BIDS/test/test_data/sub-fxclass0004_seg-subreg_msk.nii.gz", - "BIDS/test/test_data/sub-fxclass0004_seg-vert_msk.nii.gz", -] -out_name_sub = "BIDS/test/test_data/sub-fxclass0004_seg-subreg_reg-0001_msk.nii.gz" -out_name_vert = "BIDS/test/test_data/sub-fxclass0004_seg-vert_reg-0001_msk.nii.gz" - - -class Test_registration(unittest.TestCase): - @unittest.skipIf(not Path(test_data[0]).exists(), "requires real data test data") - def test_seg_registration(self): - pass - # TODO OUTDATED - # t = ridged_segmentation_from_seg(*test_data, verbose=True, ids=list(range(40, 50)), exclusion=[19]) - # slice = t.compute_crop(dist=20) - # nii_out = t.transform_nii(moving_img_nii=(test_data[2], True), slice=slice) - # nii_out.save(out_name_sub) - # nii_out = t.transform_nii(moving_img_nii=(test_data[3], True), slice=slice) - # nii_out.save(out_name_vert) - - -if __name__ == "__main__": - unittest.main() - -# @unittest.skipIf(condition, reason) -# with self.subTest(i=i): From 889432fa5a48212d409067742c418f47b0a8e019 Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Thu, 30 Oct 2025 14:55:39 +0000 Subject: [PATCH 21/28] small bug with new name --- TPTBox/segmentation/VibeSeg/auto_download.py | 6 ++--- TPTBox/segmentation/VibeSeg/vibeseg.py | 4 ++-- TPTBox/segmentation/__init__.py | 4 ++-- TPTBox/segmentation/nnUnet_utils/predictor.py | 23 ++++++++++++++++--- TPTBox/segmentation/spineps.py | 18 ++++++++++++++- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/auto_download.py b/TPTBox/segmentation/VibeSeg/auto_download.py index 1dadb92..afc96af 100644 --- a/TPTBox/segmentation/VibeSeg/auto_download.py +++ b/TPTBox/segmentation/VibeSeg/auto_download.py @@ -25,8 +25,8 @@ from tqdm import tqdm logger = logging.getLogger(__name__) -WEIGHTS_URL_ = "https://github.com/robert-graf/TotalVibeSegmentator/releases/download/v1.0.0/" -env_name = "TOTALVIBE_WEIGHTS_PATH" +WEIGHTS_URL_ = "https://github.com/robert-graf/VibeSegmentator/releases/download/v1.0.0/" +env_name = "VIBESEG_WEIGHTS_PATH" def get_weights_dir(idx, model_path: Path | None = None) -> Path: @@ -35,7 +35,7 @@ def get_weights_dir(idx, model_path: Path | None = None) -> Path: elif model_path is not None and model_path.exists(): weights_dir = model_path else: - assert Path(__file__).parent.name == "TotalVibeSeg", Path(__file__).parent + assert Path(__file__).parent.name == "VibeSeg", Path(__file__).parent weights_dir = Path(__file__).parent.parent / "nnUNet/nnUNet_results" diff --git a/TPTBox/segmentation/VibeSeg/vibeseg.py b/TPTBox/segmentation/VibeSeg/vibeseg.py index 51a26e3..49e9c7d 100644 --- a/TPTBox/segmentation/VibeSeg/vibeseg.py +++ b/TPTBox/segmentation/VibeSeg/vibeseg.py @@ -93,7 +93,7 @@ def run_vibeseg( keep_size=False, # Keep size of the model Segmentation **args, ): - run_inference_on_file( + return run_inference_on_file( dataset_id, [to_nii(i)], out_file=out_seg, @@ -103,7 +103,7 @@ def run_vibeseg( padd=padd, keep_size=keep_size, **args, - ) + )[0] def run_nnunet( diff --git a/TPTBox/segmentation/__init__.py b/TPTBox/segmentation/__init__.py index ced3f44..a0fa5d6 100644 --- a/TPTBox/segmentation/__init__.py +++ b/TPTBox/segmentation/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations from TPTBox.segmentation.spineps import run_spineps_all, run_spineps_single -from TPTBox.segmentation.VibeSeg import extract_vertebra_bodies_from_VibeSeg, run_inference_on_file, run_nnunet, run_vibeseg -from TPTBox.segmentation.VibeSeg import run_vibeseg as run_totalvibeseg # TODO deprecate +from TPTBox.segmentation.VibeSeg.vibeseg import extract_vertebra_bodies_from_VibeSeg, run_inference_on_file, run_nnunet, run_vibeseg +from TPTBox.segmentation.VibeSeg.vibeseg import run_vibeseg as run_totalvibeseg # TODO deprecate diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index 6818535..be37fb4 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -130,11 +130,26 @@ def initialize_from_trained_model_folder( ], "use_mask_for_norm": plans["use_mask_for_norm"], "resampling_fn_data": "resample_data_or_seg_to_shape", - "resampling_fn_data_kwargs": {"is_seg": False, "order": 3, "order_z": 0, "force_separate_z": None}, + "resampling_fn_data_kwargs": { + "is_seg": False, + "order": 3, + "order_z": 0, + "force_separate_z": None, + }, "resampling_fn_seg": "resample_data_or_seg_to_shape", - "resampling_fn_seg_kwargs": {"is_seg": True, "order": 1, "order_z": 0, "force_separate_z": None}, + "resampling_fn_seg_kwargs": { + "is_seg": True, + "order": 1, + "order_z": 0, + "force_separate_z": None, + }, "resampling_fn_probabilities": "resample_data_or_seg_to_shape", - "resampling_fn_probabilities_kwargs": {"is_seg": False, "order": 1, "order_z": 0, "force_separate_z": None}, + "resampling_fn_probabilities_kwargs": { + "is_seg": False, + "order": 1, + "order_z": 0, + "force_separate_z": None, + }, **plans["plans_per_stage"][0], } } @@ -709,6 +724,8 @@ def get_slices(self): def empty_cache(device: torch.device): + if isinstance(device, str): + device = torch.device(device) if device.type == "cuda": torch.cuda.empty_cache() elif device.type == "mps": diff --git a/TPTBox/segmentation/spineps.py b/TPTBox/segmentation/spineps.py index f3592f5..139e2a3 100644 --- a/TPTBox/segmentation/spineps.py +++ b/TPTBox/segmentation/spineps.py @@ -2,6 +2,7 @@ import subprocess from pathlib import Path +from typing import Literal from TPTBox import BIDS_FILE, NII, Print_Logger @@ -18,7 +19,22 @@ def get_outpaths_spineps_single( dataset=None, derivative_name="derivative", ignore_bids_filter=True, -): +) -> dict[ + Literal[ + "out_spine", + "out_spine_raw", + "out_vert", + "out_vert_raw", + "out_unc", + "out_logits", + "out_snap", + "out_ctD", + "out_snap2", + "out_debug", + "out_raw", + ], + Path, +]: from spineps.seg_run import output_paths_from_input if not isinstance(file_path, BIDS_FILE): From 82a58a4dd799334ea7e4cebba88079d6d88b5f6e Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Fri, 28 Nov 2025 23:49:23 +0000 Subject: [PATCH 22/28] added new features --- TPTBox/core/bids_constants.py | 91 ++++++++++++++++ TPTBox/core/dicom/dicom2nii_utils.py | 48 ++++++++ TPTBox/core/dicom/dicom_extract.py | 40 ++++++- TPTBox/core/dicom/dicom_header_to_keys.py | 40 +++++-- TPTBox/core/nii_poi_abstract.py | 9 +- TPTBox/core/nii_wrapper.py | 79 +++++++++----- TPTBox/core/np_utils.py | 21 +++- TPTBox/logger/log_file.py | 5 +- .../segmentation/VibeSeg/inference_nnunet.py | 35 +++--- TPTBox/segmentation/__init__.py | 2 +- TPTBox/segmentation/spineps.py | 11 +- unit_tests/test_auto_segmentation.py | 103 ++++++++++++++++++ 12 files changed, 404 insertions(+), 80 deletions(-) create mode 100644 unit_tests/test_auto_segmentation.py diff --git a/TPTBox/core/bids_constants.py b/TPTBox/core/bids_constants.py index 6c9c677..bbeb4ea 100755 --- a/TPTBox/core/bids_constants.py +++ b/TPTBox/core/bids_constants.py @@ -138,7 +138,15 @@ "recon", "reformat", "subtraction", + "DSA", + "DSA3D", + "3DRA", + "XA", "RI", # Raw input + "tmax", + "cbv", + "mtt", + "cbf", "stat", "snp", "log", @@ -158,6 +166,88 @@ formats_relaxed = [*formats, "t2", "t1", "t2c", "t1c", "cta", "mr", "snapshot", "t1dixon", "dwi"] # Recommended writing style: T1c, T2c; This list is not official and can be extended. +modalities = { + "AR": "Autorefraction", + "AS": "Angioscopy (Retired)", + "ASMT": "Content Assessment Results", + "AU": "Audio", + "BDUS": "Bone Densitometry (ultrasound)", + "BI": "Biomagnetic imaging", + "BMD": "Bone Densitometry (X-Ray)", + "CD": "Color flow Doppler (Retired)", + "CF": "Cinefluorography (Retired)", + "CP": "Colposcopy (Retired)", + "CR": "Computed Radiography", + "CS": "Cystoscopy (Retired)", + "CT": "Computed Tomography", + "DD": "Duplex Doppler (Retired)", + "DF": "Digital fluoroscopy (Retired)", + "DG": "Diaphanography", + "DM": "Digital microscopy (Retired)", + "DOC": "Document", + "DS": "Digital Subtraction Angiography (Retired)", + "DX": "Digital Radiography", + "EC": "Echocardiography (Retired)", + "ECG": "Electrocardiography", + "EPS": "Cardiac Electrophysiology", + "ES": "Endoscopy", + "FA": "Fluorescein angiography (Retired)", + "FID": "Fiducials", + "FS": "Fundoscopy (Retired)", + "GM": "General Microscopy ", + "HC": "Hard Copy", + "HD": "Hemodynamic Waveform", + "IO": "Intra-Oral Radiography", + "IOL": "Intraocular Lens Data", + "IVOCT": "Intravascular Optical Coherence Tomography", + "IVUS": "Intravascular Ultrasound", + "KER": "Keratometry", + "KO": "Key Object Selection", + "LEN": "Lensometry", + "LP": "Laparoscopy (Retired)", + "LS": "Laser surface scan", + "MA": "Magnetic resonance angiography (Retired)", + "MG": "Mammography", + "MR": "Magnetic Resonance", + "MS": "Magnetic resonance spectroscopy (Retired)", + "NM": "Nuclear Medicine", + "OAM": "Ophthalmic Axial Measurements", + "OCT": "Optical Coherence Tomography (non-Ophthalmic)", + "OP": "Ophthalmic Photography", + "OPM": "Ophthalmic Mapping", + "OPR": "Ophthalmic Refraction (Retired)", + "OPT": "Ophthalmic Tomography", + "OPV": "Ophthalmic Visual Field", + "OSS": "Optical Surface Scan", + "OT": "Other ", + "PLAN": "Plan", + "PR": "Presentation State", + "PT": "Positron emission tomography (PET)", + "PX": "Panoramic X-Ray", + "REG": "Registration", + "RESP": "Respiratory Waveform", + "RF": "Radio Fluoroscopy", + "RG": "Radiographic imaging (conventional film/screen)", + "RTDOSE": "Radiotherapy Dose", + "RTIMAGE": "Radiotherapy Image", + "RTPLAN": "Radiotherapy Plan", + "RTRECORD": "RT Treatment Record", + "RTSTRUCT": "Radiotherapy Structure Set", + "RWV": "Real World Value Map", + "SEG": "Segmentation", + "SM": "Slide Microscopy", + "SMR": "Stereometric Relationship", + "SR": "SR Document", + "SRF": "Subjective Refraction", + "ST": "Single-photon emission computed tomography (SPECT) (Retired)", + "STAIN": "Automated Slide Stainer", + "TG": "Thermography", + "US": "Ultrasound", + "VA": "Visual Acuity", + "VF": "Videofluorography (Retired)", + "XA": "X-Ray Angiography", + "XC": "External-camera Photography", +} # Actual official final folder # func (task based and resting state functional MRI) @@ -246,6 +336,7 @@ # Others (never used) "Split": "split", "Density": "den", + "version": "version", "Description": "desc", "nameconflict": "nameconflict", } diff --git a/TPTBox/core/dicom/dicom2nii_utils.py b/TPTBox/core/dicom/dicom2nii_utils.py index 9485698..4b661c7 100755 --- a/TPTBox/core/dicom/dicom2nii_utils.py +++ b/TPTBox/core/dicom/dicom2nii_utils.py @@ -3,6 +3,7 @@ import json import pickle from copy import deepcopy +from datetime import date from pathlib import Path import numpy as np @@ -152,9 +153,56 @@ def clean_dicom_data(dcm_data) -> dict: for tag in ["00291010", "00291020"]: if tag in py_dict and "InlineBinary" in py_dict[tag]: del py_dict[tag]["InlineBinary"] + py_dict = replace_birthdate_with_age(py_dict) return py_dict +def replace_birthdate_with_age(d): + try: + # DICOM tags + BIRTH_TAG = "00100030" # PatientBirthDate + STUDY_DATE_TAG = "00080020" # StudyDate + AGE_TAG = "00101010" # PatientAge + + birth_str = d.get(BIRTH_TAG, {}).get("Value", [None])[0] + study_str = d.get(STUDY_DATE_TAG, {}).get("Value", [None])[0] + + if not birth_str: + return d # no birth date, nothing to do + + # Parse birth date safely + try: + year = int(birth_str[:4]) + month = int(birth_str[4:6]) if len(birth_str) >= 6 and birth_str[4:6] != "00" else 6 + day = int(birth_str[6:8]) if len(birth_str) == 8 and birth_str[6:8] != "00" else 15 + birth_date = date(year, month, day) + except Exception: + return d # invalid date format, skip + + # Reference date (study date or today) + try: + ref_date = date( + int(study_str[:4]), + int(study_str[4:6]) if study_str[4:6] != "00" else 6, + int(study_str[6:8]) if study_str[6:8] != "00" else 15, + ) + except Exception: + ref_date = date.today() + + # Compute integer age + age = ref_date.year - birth_date.year - ((ref_date.month, ref_date.day) < (birth_date.month, birth_date.day)) + + # Replace PatientBirthDate with PatientAge + d.pop(BIRTH_TAG, None) + d[AGE_TAG] = { + "vr": "AS", # Age String + "Value": [f"{age:03d}Y"], # DICOM age format (e.g. '034Y') + } + except Exception: + pass + return d + + def get_json_from_dicom(data: list[pydicom.FileDataset] | pydicom.FileDataset): if isinstance(data, list): data = data[0] diff --git a/TPTBox/core/dicom/dicom_extract.py b/TPTBox/core/dicom/dicom_extract.py index 124a2a5..325d7a8 100644 --- a/TPTBox/core/dicom/dicom_extract.py +++ b/TPTBox/core/dicom/dicom_extract.py @@ -90,7 +90,9 @@ def _generate_bids_path( def dicom_to_nifti_multiframe(ds, nii_path): pixel_array = ds.pixel_array - n_frames, n_rows, n_cols = pixel_array.shape + if len(pixel_array.shape) != 3 and len(pixel_array.shape) != 4: + raise ValueError(f"Expected a shape with 3 colums not {len(pixel_array.shape)}; {pixel_array.shape=}") + n_frames = pixel_array.shape[0] # Pixel spacing (mm) if hasattr(ds, "PixelSpacing"): @@ -118,15 +120,37 @@ def dicom_to_nifti_multiframe(ds, nii_path): affine[0:3, 1] = col_cosines * dy affine[0:3, 2] = slice_cosines * dz affine[0:3, 3] = origin - else: + nii = nib.Nifti1Image(np.transpose(pixel_array, (2, 1, 0)), affine) + + elif hasattr(ds, "ImagerPixelSpacing"): dy, dx = (float(v) for v in ds.ImagerPixelSpacing) # Einfaches affine (nur 2D + Zeit, keine Lage im Patientenraum) affine = np.eye(4) affine[0, 0] = -dx affine[1, 1] = -dy + nii = nib.Nifti1Image(np.transpose(pixel_array, (2, 1, 0)), affine) + + else: + if hasattr(ds, "RelatedSeriesSequence"): + raise NotImplementedError("RelatedSeriesSequence Affine lookup not implemented") + raise NotImplementedError("No spatial metadata found") + ### Some could be solved by looking up the "RelatedSeriesSequence" + # "RelatedSeriesSequence": [ + # { + # "StudyInstanceUID": "1.2.276.0.38.1.1.1.7712.20250929100319.54200288", + # "SeriesInstanceUID": "1.3.46.670589.7.8.1.6.1403526999.1.9608.1759142950287.2", + # "PurposeOfReferenceCodeSequence": [] + # } + # ], + # --- No geometry info (e.g. RGB screen captures or video frames) --- + print("⚠️ No spatial metadata found — assuming pixel size = 1mm and identity orientation.") + affine = np.eye(4) + affine[0, 0] = 1.0 + affine[1, 1] = 1.0 + affine[2, 2] = 1.0 + nii = nib.Nifti1Image(pixel_array, affine) # Reihenfolge anpassen: Nibabel erwartet (X,Y,Z) - nii = nib.Nifti1Image(np.transpose(pixel_array, (2, 1, 0)), affine) nib.save(nii, nii_path) return nii_path @@ -491,6 +515,7 @@ def extract_dicom_folder( validate_slicecount=True, validate_orientation=True, validate_orthogonal=False, + validate_slice_increment=True, n_cpu: int | None = 1, override_subject_name: Callable[[dict, Path], str] | None = None, skip_localizer=True, @@ -517,7 +542,8 @@ def extract_dicom_folder( convert_dicom.settings.disable_validate_orientation() if not validate_orthogonal: convert_dicom.settings.disable_validate_orthogonal() - + if not validate_slice_increment: + convert_dicom.settings.disable_validate_slice_increment() outs = {} for p in _find_all_files(dicom_folder): @@ -566,6 +592,8 @@ def process_series(key, files, parts): try: key2, out = process_series(key, files, parts) outs[key2] = out + except NotImplementedError as e: + logger.on_warning("NotImplementedError:", e) except Exception: logger.print_error() @@ -577,8 +605,8 @@ def process_series(key, files, parts): if __name__ == "__main__": - for p in Path("E:/DSA_Data/DICOMS_DSA_all/").iterdir(): - extract_dicom_folder(p, Path("D:/data/DSA", "dataset-DSA"), False, False) + for p in Path("/DATA/NAS/datasets_source/brain/dsa").iterdir(): + extract_dicom_folder(p, Path("/DATA/NAS/datasets_source/brain/", "dataset-DSA"), False, False, validate_slice_increment=False) sys.exit() # s = "/home/robert/Downloads/bein/dataset-oberschenkel/rawdata/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497/mr/sub-1-3-46-670589-11-2889201787-2305829596-303261238-2367429497_sequ-406_mr.nii.gz" diff --git a/TPTBox/core/dicom/dicom_header_to_keys.py b/TPTBox/core/dicom/dicom_header_to_keys.py index 6279857..b9156ec 100644 --- a/TPTBox/core/dicom/dicom_header_to_keys.py +++ b/TPTBox/core/dicom/dicom_header_to_keys.py @@ -10,7 +10,7 @@ import pydicom from dicom2nifti import common -from TPTBox.core.bids_constants import formats +from TPTBox.core.bids_constants import formats, modalities from TPTBox.core.nii_wrapper import NII, to_nii dixon_mapping = { @@ -266,7 +266,25 @@ def _get(key, default=None): found = False if modality == "ct": mri_format = "ct" - else: + elif modality == "xa": # Angiography + if "BIPLANE A" in image_type or "SINGLE A" in image_type: + keys["acq"] = "A" + elif "BIPLANE B" in image_type or "SINGLE B" in image_type: + keys["acq"] = "B" + monitor = _get("PositionerMotion", " ").lower() + # ftv = _get("FrameTimeVector", None).lower() + monitor = _get("PositionerMotion", " ").lower() + tag = _get("DerivationDescription", " ").lower() + # ftv is not None + if tag == "subtraction": + mri_format = "DSA" if monitor == "static" and "VOLUME" not in image_type and "RECON" not in image_type else "subtraction" + elif "3DRA_PROP" in image_type: + mri_format = "3DRA" + elif monitor == "dynamic" or "VOLUME" in image_type or "RECON" in image_type or "3DRA_PROP" in image_type: + mri_format = "DSA3D" + else: + mri_format = "XA" + elif modality == "mr": for key, mri_format_new in map_series_description_to_file_format.items(): regex = re.compile(key) if re.match(regex, series_description): @@ -280,14 +298,16 @@ def _get(key, default=None): break if mri_format is None: mri_format = "mr" - if mri_format == "T1w": - if "sub" in series_description.lower() and keys.get("part") is None: - keys["part"] = "subtraction" - if ( - " km " in series_description.lower() or series_description.startswith("km") or series_description.endswith("km") - ) and keys.get("ce") is None: - keys["ce"] = "ContrastAgent" + if mri_format == "T1w": + if "sub" in series_description.lower() and keys.get("part") is None: + keys["part"] = "subtraction" + if ( + " km " in series_description.lower() or series_description.startswith("km") or series_description.endswith("km") + ) and keys.get("ce") is None: + keys["ce"] = "ContrastAgent" + else: + raise NotImplementedError(f"modality='{modality.upper()}', ({modalities.get(modality.upper())})") + # ".*sub.*t1.*": "subtraktion", # "subtraktion.*t1.*": "subtraktion", - return mri_format, keys diff --git a/TPTBox/core/nii_poi_abstract.py b/TPTBox/core/nii_poi_abstract.py index e04edce..0ca2632 100755 --- a/TPTBox/core/nii_poi_abstract.py +++ b/TPTBox/core/nii_poi_abstract.py @@ -74,8 +74,9 @@ def __str__(self) -> str: except Exception: origin = self.origin try: - zoom = tuple(np.around(self.zoom, decimals=2).tolist()) - except Exception: + zoom = "(" + ",".join([f"{a:.2f}" for a in self.zoom]) + ")" + except Exception as e: + print(e) zoom = self.zoom return f"shape={self.shape_int},spacing={zoom}, origin={origin}, ori={self.orientation}" # type: ignore @@ -127,11 +128,11 @@ def change_affine( Apply a transformation (translation, rotation, scaling) to the affine matrix. Parameters: - translation: (n,) array-like in mm + translation: (n,) array-like in mm in (R, A, S) rotation_degrees: (n,) array-like (pitch, yaw, roll) in degrees scaling: (n,) array-like scaling factors along x, y, z """ - warnings.warn("change_affine is untested", stacklevel=2) + # warnings.warn("change_affine is untested", stacklevel=2) n = self.affine.shape[0] transform = np.eye(n) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 66abd86..89cd685 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -59,6 +59,7 @@ SHAPE, ZOOMS, Location, + _same_direction, log, logging, v_name2idx, @@ -666,7 +667,7 @@ def astype(self,dtype,order:Literal["C","F","A","K"] ='K', casting:Literal["no", return c else: return self.get_array().astype(dtype,order=order,casting=casting, subok=subok,copy=copy) - def reorient(self:Self, axcodes_to: AX_CODES = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self: + def reorient(self:Self, axcodes_to: AX_CODES|None = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self: """ Reorients the input Nifti image to the desired orientation, specified by the axis codes. @@ -683,30 +684,34 @@ def reorient(self:Self, axcodes_to: AX_CODES = ("P", "I", "R"), verbose:logging= """ # Note: nibabel axes codes describe the direction not origin of axes # direction PIR+ = origin ASL - - aff = self.affine - ornt_fr = self.orientation_ornt - arr = self.get_array() - ornt_to = nio.axcodes2ornt(axcodes_to) - ornt_trans = nio.ornt_transform(ornt_fr, ornt_to) - if (ornt_fr == ornt_to).all(): - log.print("Image is already rotated to", axcodes_to,verbose=verbose) - if inplace: - return self - return self.copy() # type: ignore - arr = nio.apply_orientation(arr, ornt_trans) - aff_trans = nio.inv_ornt_aff(ornt_trans, arr.shape) - new_aff = np.matmul(aff, aff_trans) - ### Reset origin ### - flip = ornt_trans[:, 1] - change = ((-flip) + 1) / 2 # 1 if flip else 0 - change = tuple(a * (s-1) for a, s in zip(change, self.shape)) - new_aff[:3, 3] = nib.affines.apply_affine(aff,change) # type: ignore - ###### - #if self.header is not None: - # self.header.set_sform(new_aff, code=1) - new_img = arr, new_aff,self.header - log.print("Image reoriented from", nio.ornt2axcodes(ornt_fr), "to", axcodes_to,verbose=verbose) + if isinstance(axcodes_to,str): + axcodes_to = tuple(axcodes_to) + if axcodes_to is not None: + aff = self.affine + ornt_fr = self.orientation_ornt + arr = self.get_array() + ornt_to = nio.axcodes2ornt(axcodes_to) + ornt_trans = nio.ornt_transform(ornt_fr, ornt_to) + if (ornt_fr == ornt_to).all(): + log.print("Image is already rotated to", axcodes_to,verbose=verbose) + if inplace: + return self + return self.copy() # type: ignore + arr = nio.apply_orientation(arr, ornt_trans) + aff_trans = nio.inv_ornt_aff(ornt_trans, arr.shape) + new_aff = np.matmul(aff, aff_trans) + ### Reset origin ### + flip = ornt_trans[:, 1] + change = ((-flip) + 1) / 2 # 1 if flip else 0 + change = tuple(a * (s-1) for a, s in zip(change, self.shape)) + new_aff[:3, 3] = nib.affines.apply_affine(aff,change) # type: ignore + ###### + #if self.header is not None: + # self.header.set_sform(new_aff, code=1) + new_img = arr, new_aff,self.header + log.print("Image reoriented from", nio.ornt2axcodes(ornt_fr), "to", axcodes_to,verbose=verbose) + else: + return self if not inplace else self.copy() if inplace: self.nii = new_img return self @@ -1420,7 +1425,8 @@ def filter_connected_components(self, labels: int |list[int]|None=None,min_volum max_count_component (int | None): Maximum number of components to retain. Once this limit is reached, remaining components will be removed. connectivity (int): Connectivity criterion for defining connected components (default is 3). removed_to_label (int): Label to assign to removed components (default is 0). - + TODO : max_count_component currently filters over all labels instead of per label. will be changed. + TODO : removed_to_label does not work when keep_label=False Returns: None """ @@ -1437,8 +1443,8 @@ def filter_connected_components(self, labels: int |list[int]|None=None,min_volum #print("filter",nii.unique()) #assert max_count_component is None or nii.max() <= max_count_component, nii.unique() return self.set_array(arr, inplace=inplace) - def filter_connected_components_(self, labels: int |list[int]|None=None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False): - return self.filter_connected_components(labels,min_volume=min_volume,max_volume=max_volume, max_count_component = max_count_component, connectivity = connectivity,keep_label=keep_label,inplace=True) + def filter_connected_components_(self, labels: int |list[int]|None=None,min_volume:int=0,max_volume:int|None=None, max_count_component = None, connectivity: int = 3,keep_label=False,removed_to_label=0): + return self.filter_connected_components(labels,min_volume=min_volume,max_volume=max_volume, max_count_component = max_count_component, connectivity = connectivity,removed_to_label=removed_to_label,keep_label=keep_label,inplace=True) def get_segmentation_connected_components_center_of_mass(self, label: int, connectivity: int = 3, sort_by_axis: int | None = None) -> list[COORDINATE]: """Calculates the center of mass of the different connected components of a given label in an array @@ -1835,6 +1841,16 @@ def copy(self, nib:Nifti1Image|_unpacked_nii|None = None,seg=None): nib = (self.get_array().copy(), self.affine.copy(), self.header.copy()) return NII(nib,seg=self.seg if seg is None else seg,c_val = self.c_val,info = self.info) + + def flip(self, axis:int|str,keep_global_coords=True,inplace=False): + axis = self.get_axis(axis) if not isinstance(axis,int) else axis + if keep_global_coords: + orient = list(self.orientation) + orient[axis] = _same_direction[orient[axis] ] + return self.reorient(tuple(orient),inplace=inplace) + else: + return self.set_array(np.flip(self.get_array(),axis),inplace=inplace) + def clone(self): return self.copy() @secure_save @@ -2011,9 +2027,14 @@ def extract_background(self,inplace=False): arr_bg = np_extract_label(arr_bg, label=0, to_label=1) return self.set_array(arr_bg, inplace, False) - def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum], keep_label=False,inplace=False): + def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_label=False,inplace=False): '''If this NII is a segmentation you can single out one label with [0,1].''' assert self.seg, "extracting a label only makes sense for a segmentation mask" + if label is None: + if keep_label: + return self.copy() if inplace else self + else: + return self.clamp(0,1,inplace=inplace) seg_arr = self.get_seg_array() if isinstance(label, Sequence): diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 05870eb..6816315 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -792,6 +792,7 @@ def np_filter_connected_components( arr2[np.isin(arr2, labels, invert=True)] = 0 # type:ignore labels_out, n = _connected_components(arr2, connectivity=connectivity, return_N=True) + largest_k_components_org = largest_k_components if largest_k_components is None: largest_k_components = n assert largest_k_components is not None @@ -801,8 +802,23 @@ def np_filter_connected_components( ] largest_k_components = min(largest_k_components, len(label_volume_pairs)) label_volume_pairs.sort(key=lambda x: x[1], reverse=True) - preserve: list[int] = [x[0] for x in label_volume_pairs[:largest_k_components]] + if len(labels) == 1 or label_volume_pairs == largest_k_components or largest_k_components_org is None: + preserve: list[int] = [x[0] for x in label_volume_pairs[:largest_k_components]] + else: + counter = dict.fromkeys(labels, 0) + preserve = [] + for preserve_label, _ in label_volume_pairs: + idx = arr[labels_out == preserve_label].max() + if counter.get(idx, largest_k_components + 1) <= largest_k_components_org: + preserve.append(preserve_label) + counter[idx] += 1 + # print("add perserve", idx) + if counter.get(idx, largest_k_components + 1) == largest_k_components_org: + del counter[idx] + # print("del perserve", idx) + if len(counter) == 0: + break cc_out = np.zeros(arr.shape, dtype=arr.dtype) i = 1 for preserve_label in preserve: @@ -1014,9 +1030,6 @@ def np_smooth_gaussian_labelwise( if isinstance(label_to_smooth, int): label_to_smooth = [label_to_smooth] - for l in label_to_smooth: - assert l in sem_labels, f"You want to smooth label {l} but it is not present in the given segmentation mask" - if dilate_prior > 0 and not dilate_channelwise: arr = np_dilate_msk( arr, diff --git a/TPTBox/logger/log_file.py b/TPTBox/logger/log_file.py index 0db5f7b..ad8592d 100755 --- a/TPTBox/logger/log_file.py +++ b/TPTBox/logger/log_file.py @@ -52,7 +52,7 @@ def print( """ if verbose is None: verbose = getattr(self, "default_verbose", False) - if len(text) == 0 or text == [""] or text == "" or text is None: + if len(text) == 0 or text in ([""], "") or text is None: ignore_prefix = True string: str = "" else: @@ -211,6 +211,9 @@ def on_warning(self, *text, end="\n", verbose: bool | None = None, **qargs): def on_text(self, *text, end="\n", verbose: bool | None = None, **qargs): self.print(*text, end=end, ltype=Log_Type.TEXT, verbose=verbose, **qargs) + def info(self, *text, end="\n", verbose: bool | None = None, **qargs): + self.print(*text, end=end, ltype=Log_Type.TEXT, verbose=verbose, **qargs) + class Logger(Logger_Interface): """ diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 4f1295a..889c4e4 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -127,23 +127,24 @@ def run_inference_on_file( try: zoom_old = ds_info.get("spacing") - zoom = plans_info["configurations"]["3d_fullres"]["spacing"] - order = plans_info["transpose_backward"] - # order2 = plans_info["transpose_forward"] - zoom = [zoom[order[0]], zoom[order[1]], zoom[order[2]]][::-1] - orientation_ref = ("P", "I", "R") - orientation_ref = [ - orientation_ref[order[0]], - orientation_ref[order[1]], - orientation_ref[order[2]], - ] # [::-1] + + if zoom_old is None: + zoom = plans_info["configurations"]["3d_fullres"]["spacing"] + if all(zoom[0] == z for z in zoom): + zoom_old = zoom + # order = plans_info["transpose_backward"] + ## order2 = plans_info["transpose_forward"] + # zoom = [zoom[order[0]], zoom[order[1]], zoom[order[2]]][::-1] + # orientation_ref = ("P", "I", "R") + # orientation_ref = [ + # orientation_ref[order[0]], + # orientation_ref[order[1]], + # orientation_ref[order[2]], + # ] # [::-1] # zoom_old = zoom_old[::-1] - if zoom is None: - pass - else: - zoom = [float(z) for z in zoom] + zoom_old = [float(z) for z in zoom_old] except Exception: pass assert len(ds_info["channel_names"]) == len(input_nii), ( @@ -156,9 +157,9 @@ def run_inference_on_file( print("orientation", orientation, f"{orientation_ref=}") if verbose else None input_nii = [i.reorient(orientation) for i in input_nii] - if zoom is not None: - print("rescale", zoom, f"{zoom_old=}, {order=}") if verbose else None - input_nii = [i.rescale_(zoom, mode=mode) for i in input_nii] + if zoom_old is not None: + print("rescale", zoom, f"{zoom_old=}") if verbose else None + input_nii = [i.rescale_(zoom_old, mode=mode) for i in input_nii] print(input_nii) print("squash to float16") if verbose else None input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] diff --git a/TPTBox/segmentation/__init__.py b/TPTBox/segmentation/__init__.py index a0fa5d6..0601e1f 100644 --- a/TPTBox/segmentation/__init__.py +++ b/TPTBox/segmentation/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from TPTBox.segmentation.spineps import run_spineps_all, run_spineps_single +from TPTBox.segmentation.spineps import _run_spineps_all, run_spineps from TPTBox.segmentation.VibeSeg.vibeseg import extract_vertebra_bodies_from_VibeSeg, run_inference_on_file, run_nnunet, run_vibeseg from TPTBox.segmentation.VibeSeg.vibeseg import run_vibeseg as run_totalvibeseg # TODO deprecate diff --git a/TPTBox/segmentation/spineps.py b/TPTBox/segmentation/spineps.py index 139e2a3..6ea3657 100644 --- a/TPTBox/segmentation/spineps.py +++ b/TPTBox/segmentation/spineps.py @@ -9,12 +9,7 @@ logger = Print_Logger() -def injection_function(seg_nii: NII): - # TODO do something with semantic mask - return seg_nii - - -def get_outpaths_spineps_single( +def get_outpaths_spineps( file_path: str | Path | BIDS_FILE, dataset=None, derivative_name="derivative", @@ -50,7 +45,7 @@ def get_outpaths_spineps_single( return output_paths -def run_spineps_single( +def run_spineps( file_path: str | Path | BIDS_FILE, dataset=None, model_semantic: str | Path = "t2w", @@ -111,7 +106,7 @@ def run_spineps_single( return output_paths -def run_spineps_all(nii_dataset: Path | str): +def _run_spineps_all(nii_dataset: Path | str): for model_semantic in ["t2w", "t1w", "vibe"]: command = [ "spineps", diff --git a/unit_tests/test_auto_segmentation.py b/unit_tests/test_auto_segmentation.py new file mode 100644 index 0000000..9ba3836 --- /dev/null +++ b/unit_tests/test_auto_segmentation.py @@ -0,0 +1,103 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +from __future__ import annotations + +import random +import shutil +import sys +import tempfile +from pathlib import Path + +import numpy as np + +from TPTBox.core.nii_wrapper import to_nii + +file = Path(__file__).resolve() +sys.path.append(str(file.parents[2])) + +import unittest # noqa: E402 + +from TPTBox import NII, Location, Print_Logger, calc_poi_from_subreg_vert # noqa: E402 +from TPTBox.tests.test_utils import get_test_ct, get_test_mri, get_tests_dir # noqa: E402 + +has_spineps = False +try: + import spineps + + has_spineps = True +except ModuleNotFoundError: + has_spineps = False + + +class Test_test_samples(unittest.TestCase): + # def test_load_ct(self): + # ct_nii, subreg_nii, vert_nii, label = get_test_ct() + # self.assertTrue(ct_nii.assert_affine(other=subreg_nii, raise_error=False)) + # self.assertTrue(ct_nii.assert_affine(other=vert_nii, raise_error=False)) + + # l3 = vert_nii.extract_label(label) + # l3_subreg = subreg_nii.apply_mask(l3, inplace=False) + # self.assertEqual(l3.volumes()[1], sum(l3_subreg.volumes(include_zero=False).values())) + @unittest.skipIf(not has_spineps, "requires spineps to be installed") + def test_get_outpaths_spineps(self): + tests_path = get_tests_dir() + + from TPTBox.segmentation.spineps import get_outpaths_spineps + + mri_path = tests_path.joinpath("sample_mri") + mri_path = mri_path.joinpath("sub-mri_label-6_T2w.nii.gz") + out = get_outpaths_spineps(mri_path, tests_path) + assert "out_spine" in out + assert "out_vert" in out + + @unittest.skipIf(not has_spineps, "requires spineps to be installed") + def test_spineps(self): + tests_path = get_tests_dir() + + mri_nii, subreg_nii, vert_nii, label = get_test_mri() + from TPTBox.segmentation.spineps import run_spineps + + mri_path = tests_path.joinpath("sample_mri") + mri_path = mri_path.joinpath("sub-mri_label-6_T2w.nii.gz") + out = run_spineps(mri_path, tests_path, ignore_compatibility_issues=True) + assert "out_spine" in out + assert "out_vert" in out + assert out["out_spine"].exists() + assert out["out_vert"].exists() + assert out["out_snap"].exists() + assert out["out_ctd"].exists() + + vert_nii = to_nii(out["out_vert"], True) + assert label in vert_nii.unique(), vert_nii.unique() + shutil.rmtree(tests_path / "derivative") + + @unittest.skipIf(not has_spineps, "requires spineps to be installed") + def test_VIBESeg(self): + tests_path = get_tests_dir() + from TPTBox.segmentation import run_vibeseg + + for i in [100, 11, 278]: + mri_path = tests_path.joinpath("sample_mri") + mri_path = mri_path.joinpath("sub-mri_label-6_T2w.nii.gz") + seg_out_path = tests_path / f"{i}_test_VIBESeg.nii.gz" + out = run_vibeseg(mri_path, seg_out_path, dataset_id=i) + assert isinstance(out, (NII, Path)) + assert seg_out_path.exists() + seg_out_path.unlink(missing_ok=True) + + @unittest.skipIf(not has_spineps, "requires spineps to be installed") + def test_VIBESeg_ct(self): + tests_path = get_tests_dir() + from TPTBox.segmentation import run_vibeseg + + for i in [100, 11, 520]: + tests_path = get_tests_dir() + ct_path = tests_path.joinpath("sample_ct") + ct_path = ct_path.joinpath("sub-ct_label-22_ct.nii.gz") + seg_out_path = tests_path / f"{i}_test_ct_VIBESeg.nii.gz" + out = run_vibeseg(ct_path, seg_out_path, dataset_id=i) + assert isinstance(out, (NII, Path)) + assert seg_out_path.exists() + seg_out_path.unlink(missing_ok=True) From a9a3aa79acdcec8fee88f1496d1a2ba2e70f4da3 Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Sat, 29 Nov 2025 00:52:06 +0000 Subject: [PATCH 23/28] fix test with updated semantic meaning --- TPTBox/core/np_utils.py | 5 ++- unit_tests/test_auto_segmentation.py | 4 +- unit_tests/test_nputils.py | 57 ++++++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 6816315..6997aa0 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -765,6 +765,7 @@ def np_filter_connected_components( min_volume: float = 0, max_volume: float | None = None, removed_to_label=0, + k_larges_global=False, ) -> UINTARRAY: """finds the largest k connected components in a given array (does NOT work with zero as label!) @@ -774,7 +775,7 @@ def np_filter_connected_components( labels (int | list[int] | None, optional): Labels that the algorithm should be applied to. If none, applies on all labels found in arr. Defaults to None. connectivity: in range [1,3]. For 2D images, 2 and 3 is the same. return_original_labels (bool): If set to False, will label the components from 1 to k. Defaults to True - + k_larges_global(bool): If true largest_k_components is filterd over all labels instead of each lable individualy Returns: np.ndarray: array with the largest k connected components """ @@ -803,7 +804,7 @@ def np_filter_connected_components( largest_k_components = min(largest_k_components, len(label_volume_pairs)) label_volume_pairs.sort(key=lambda x: x[1], reverse=True) - if len(labels) == 1 or label_volume_pairs == largest_k_components or largest_k_components_org is None: + if len(labels) == 1 or label_volume_pairs == largest_k_components or largest_k_components_org is None or k_larges_global: preserve: list[int] = [x[0] for x in label_volume_pairs[:largest_k_components]] else: counter = dict.fromkeys(labels, 0) diff --git a/unit_tests/test_auto_segmentation.py b/unit_tests/test_auto_segmentation.py index 9ba3836..41ed643 100644 --- a/unit_tests/test_auto_segmentation.py +++ b/unit_tests/test_auto_segmentation.py @@ -55,6 +55,8 @@ def test_get_outpaths_spineps(self): @unittest.skipIf(not has_spineps, "requires spineps to be installed") def test_spineps(self): tests_path = get_tests_dir() + if (tests_path / "derivative").exists(): + shutil.rmtree(tests_path / "derivative") mri_nii, subreg_nii, vert_nii, label = get_test_mri() from TPTBox.segmentation.spineps import run_spineps @@ -70,7 +72,7 @@ def test_spineps(self): assert out["out_ctd"].exists() vert_nii = to_nii(out["out_vert"], True) - assert label in vert_nii.unique(), vert_nii.unique() + assert label in vert_nii.unique(), (label, vert_nii.unique()) shutil.rmtree(tests_path / "derivative") @unittest.skipIf(not has_spineps, "requires spineps to be installed") diff --git a/unit_tests/test_nputils.py b/unit_tests/test_nputils.py index 13c52ae..aa134da 100755 --- a/unit_tests/test_nputils.py +++ b/unit_tests/test_nputils.py @@ -286,17 +286,17 @@ def test_connected_components_per_label(self): msg=f"{coms[0][0]}, {coms_compare}", ) - def test_get_largest_k_connected_components(self): + def test_get_largest_k_connected_components_non_global(self): a = np.zeros((50, 50), dtype=np.uint16) - a[10:20, 10:20] = 5 - a[30:50, 30:50] = 7 + a[10:20, 10:20] = 1 + a[30:50, 30:50] = 1 a[1:4, 1:4] = 1 # k less than N a_cc = np_utils.np_filter_connected_components(a, largest_k_components=2, return_original_labels=False) a_volume = np_utils.np_volume(a_cc) print(a_volume) - self.assertTrue(len(a_volume) == 2) + self.assertTrue(len(a_volume) == 2, a_volume) self.assertTrue(a_volume[1] > a_volume[2]) # k == N @@ -313,6 +313,55 @@ def test_get_largest_k_connected_components(self): self.assertTrue(len(a_volume) == 3) self.assertTrue(a_volume[1] > a_volume[2] > a_volume[3]) + a = np.zeros((50, 50), dtype=np.uint16) + a[10:20, 10:20] = 7 + a[30:50, 30:50] = 5 + a[1:4, 1:4] = 1 + + # k less than N + a_cc = np_utils.np_filter_connected_components(a, largest_k_components=2, return_original_labels=False) + a_volume = np_utils.np_volume(a_cc) + self.assertTrue(len(a_volume) == 3, a_volume) + + def test_get_largest_k_connected_components(self): + a = np.zeros((50, 50), dtype=np.uint16) + a[10:20, 10:20] = 5 + a[30:50, 30:50] = 7 + a[1:4, 1:4] = 1 + + # k less than N + a_cc = np_utils.np_filter_connected_components(a, largest_k_components=2, return_original_labels=False, k_larges_global=True) + a_volume = np_utils.np_volume(a_cc) + print(a_volume) + self.assertTrue(len(a_volume) == 2, a_volume) + self.assertTrue(a_volume[1] > a_volume[2]) + + # k == N + a_cc = np_utils.np_filter_connected_components(a, largest_k_components=3, return_original_labels=False, k_larges_global=True) + a_volume = np_utils.np_volume(a_cc) + print(a_volume) + self.assertTrue(len(a_volume) == 3) + self.assertTrue(a_volume[1] > a_volume[2] > a_volume[3]) + + # k > N + a_cc = np_utils.np_filter_connected_components(a, largest_k_components=20, return_original_labels=False, k_larges_global=True) + a_volume = np_utils.np_volume(a_cc) + print(a_volume) + self.assertTrue(len(a_volume) == 3) + self.assertTrue(a_volume[1] > a_volume[2] > a_volume[3]) + + a = np.zeros((50, 50), dtype=np.uint16) + a[10:20, 10:20] = 1 + a[30:50, 30:50] = 1 + a[1:4, 1:4] = 1 + + # k less than N + a_cc = np_utils.np_filter_connected_components(a, largest_k_components=2, return_original_labels=False, k_larges_global=True) + a_volume = np_utils.np_volume(a_cc) + print(a_volume) + self.assertTrue(len(a_volume) == 2, a_volume) + self.assertTrue(a_volume[1] > a_volume[2]) + def test_fill_holes(self): # Create a test NII object with a segmentation mask arr = np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 0, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]], dtype=np.int16) From d5e1bc91791ea8f61b91eb942fb3d4fa7524a26b Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Sat, 29 Nov 2025 00:53:42 +0000 Subject: [PATCH 24/28] add 3.9 to fast test --- .github/workflows/python-publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 388d67d..44d10a3 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: ["3.9", "3.12"] - name: Install dependencies run: | python -m pip install --upgrade pip From acea6bb7d7ec42573afe129ecc3ae358eca77474 Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Sat, 29 Nov 2025 00:57:18 +0000 Subject: [PATCH 25/28] add 3.9 to merge test --- .github/workflows/python-publish.yml | 2 +- .github/workflows/tests_mr.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 44d10a3..500c50e 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: ["3.9", "3.12"] + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tests_mr.yml b/.github/workflows/tests_mr.yml index 9131af9..9c0d43e 100644 --- a/.github/workflows/tests_mr.yml +++ b/.github/workflows/tests_mr.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.9","3.12"] steps: - uses: actions/checkout@v4 From 37c74e7a738e2f00ba413a8dee0257c2c294ffe3 Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Sat, 29 Nov 2025 02:11:50 +0000 Subject: [PATCH 26/28] remove "total" --- tutorials/tutorial_Dataset_processing.ipynb | 43 ++++++++++----------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/tutorials/tutorial_Dataset_processing.ipynb b/tutorials/tutorial_Dataset_processing.ipynb index 60e56f8..c5251ff 100644 --- a/tutorials/tutorial_Dataset_processing.ipynb +++ b/tutorials/tutorial_Dataset_processing.ipynb @@ -36,7 +36,7 @@ "\n", "(2) Stitching\n", "\n", - "(3) Segmentation TotalVibeSegmentator and Spineps\n", + "(3) Segmentation VibeSegmentator and Spineps\n", "\n", "(4) Points of Interest (POI) \n", "\n", @@ -116,7 +116,7 @@ "source": [ "from pathlib import Path\n", "\n", - "from TPTBox.segmentation.TotalVibeSeg.auto_download import _download\n", + "from TPTBox.segmentation.VibeSeg.auto_download import _download\n", "\n", "## Download the Dicom Example\n", "path_to_dicom_dataset = Path(\"tutorial_data_processing\").absolute()\n", @@ -124,9 +124,9 @@ "\n", "if not path_to_dicom_dataset.exists():\n", " _download(\n", - " \"https://github.com/robert-graf/TotalVibeSegmentator/releases/download/example/tutorial_data_processing.zip\",\n", + " \"https://github.com/robert-graf/VibeSegmentator/releases/download/example/tutorial_data_processing.zip\",\n", " path_to_dicom_dataset,\n", - " text=\"example\"\n", + " text=\"example\",\n", " )\n", "\n", "dataset_name = Path(path_to_dicom_dataset).name.replace(\"_\", \"-\") # TODO Remove\n", @@ -446,11 +446,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### TotalVibeSegmentator\n", + "### VibeSegmentator\n", "\n", "https://arxiv.org/abs/2406.00125\n", "\n", - "https://github.com/robert-graf/TotalVibeSegmentator\n" + "https://github.com/robert-graf/VibeSegmentator\n" ] }, { @@ -462,9 +462,9 @@ "from pathlib import Path\n", "\n", "from TPTBox import BIDS_FILE\n", - "from TPTBox.segmentation import run_totalvibeseg\n", + "from TPTBox.segmentation import run_vibeseg\n", "\n", - "# run_totalvibeseg\n", + "# run_vibeseg\n", "# You can also use a string/Path if you want to set the path yourself.\n", "### Just making in and output path\n", "path_to_dicom_dataset = Path(\"tutorial_data_processing\").absolute()\n", @@ -475,10 +475,10 @@ " dataset,\n", ")\n", "out_file_dixon = in_file_dixon.get_changed_path(\n", - " \"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"TotalVibeSegmentator\", \"mod\": in_file_dixon.bids_format}\n", + " \"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"VibeSegmentator\", \"mod\": in_file_dixon.bids_format}\n", ")\n", "####\n", - "run_totalvibeseg(in_file_dixon, out_file_dixon, override=False)\n" + "run_vibeseg(in_file_dixon, out_file_dixon, override=False)\n" ] }, { @@ -510,9 +510,9 @@ " f\"{dataset}/rawdata_stiched/sub-111168223/ses-20230128/T2w/sub-111168223_ses-20230128_sequ-201-stiched_acq-ax_part-inphase_T2w.nii.gz\",\n", " dataset,\n", ")\n", - "out_file = in_file.get_changed_path(\"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"TotalVibeSegmentator\", \"mod\": in_file.bids_format})\n", + "out_file = in_file.get_changed_path(\"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"VibeSegmentator\", \"mod\": in_file.bids_format})\n", "####\n", - "run_totalvibeseg(in_file, out_file, override=False)" + "run_vibeseg(in_file, out_file, override=False)" ] }, { @@ -595,10 +595,10 @@ "metadata": {}, "outputs": [], "source": [ - "from TPTBox.segmentation.spineps import run_spineps_single\n", + "from TPTBox.segmentation.spineps import run_spineps\n", "\n", "# With 'ignore_compatibility_issues = True' you can force to run the soft ware\n", - "out_paths = run_spineps_single(\n", + "out_paths = run_spineps(\n", " in_file,\n", " dataset=dataset,\n", " model_semantic=model_semantic,\n", @@ -928,7 +928,6 @@ "source": [ "from TPTBox import NII, Location, Vertebra_Instance, calc_poi_from_subreg_vert\n", "from TPTBox.registration import Point_Registration, ridged_points_from_poi\n", - "from TPTBox.segmentation.TotalVibeSeg import extract_vertebra_bodies_from_totalVibe\n", "\n", "# Example registration two sagittal images, like for compensating movement between scans.\n", "poi_fixed = calc_poi_from_subreg_vert(nii_instance1, nii_semantic1, subreg_id=[Location.Vertebra_Corpus, Location.Spinosus_Process]).round(\n", @@ -984,10 +983,10 @@ "source": [ "from TPTBox import NII, Location, Vertebra_Instance, calc_poi_from_subreg_vert\n", "from TPTBox.registration import Point_Registration, ridged_points_from_poi\n", - "from TPTBox.segmentation.TotalVibeSeg import extract_vertebra_bodies_from_totalVibe\n", + "from TPTBox.segmentation.VibeSeg.vibeseg import extract_vertebra_bodies_from_VibeSeg\n", "\n", "# Example registration axial and sagittal with points.\n", - "# T2w axial points are computed from the total vibe segment.\n", + "# T2w axial points are computed from the vibe segment.\n", "dataset = target_folder / \"dataset-tutorial-data-processing\"\n", "\n", "nii_instance_path2 = out_paths[\"out_vert\"]\n", @@ -1004,7 +1003,7 @@ " dataset,\n", ")\n", "out_file = moving_file.get_changed_path(\n", - " \"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"TotalVibeSegmentator\", \"mod\": moving_file.bids_format}\n", + " \"nii.gz\", \"msk\", parent=\"derivative\", info={\"seg\": \"VibeSegmentator\", \"mod\": moving_file.bids_format}\n", ")\n", "moving_image = to_nii(moving_file)\n", "out_file = to_nii(out_file, True)\n", @@ -1026,7 +1025,7 @@ "if Vertebra_Instance.L5.value not in poi_fixed.keys_region():\n", " num_lumbar_verts = 4\n", "# Note: this function currently assumes that we see the sacrum in the image.\n", - "nii, poi_moving = extract_vertebra_bodies_from_totalVibe(out_file, num_lumbar_verts=num_lumbar_verts, num_thoracic_verts=num_thoracic_verts)\n", + "nii, poi_moving = extract_vertebra_bodies_from_VibeSeg(out_file, num_lumbar_verts=num_lumbar_verts, num_thoracic_verts=num_thoracic_verts)\n", "\n", "registration_object: Point_Registration = ridged_points_from_poi(poi_fixed, poi_moving, c_val=0)\n", "\n", @@ -1104,7 +1103,7 @@ "fixed_nii = to_nii(Path(\"tutorial_data_processing/PixelPandemonium/mr_1.nii.gz\").absolute(), False)\n", "\n", "\n", - "reg = Deformable_Registration(fixed_nii,moving_nii, normalize_strategy=\"MRI\", device=\"cuda\",mask_foreground=False,verbose=99,lr=0.01)\n", + "reg = Deformable_Registration(fixed_nii, moving_nii, normalize_strategy=\"auto\", device=\"cuda\", verbose=99, lr=0.01)\n", "\n", "moved_nii = reg.transform_nii(moving_nii)\n" ] @@ -1174,7 +1173,7 @@ ], "metadata": { "kernelspec": { - "display_name": "py3.11", + "display_name": "py3.12", "language": "python", "name": "python3" }, @@ -1188,7 +1187,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.12.9" } }, "nbformat": 4, From 67f47981a2f648eda603ace7c664b5cfefca4918 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Sat, 29 Nov 2025 02:14:45 +0000 Subject: [PATCH 27/28] add init --- .gitignore | 2 +- TPTBox/segmentation/VibeSeg/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 TPTBox/segmentation/VibeSeg/__init__.py diff --git a/.gitignore b/.gitignore index f06f38f..e56499d 100755 --- a/.gitignore +++ b/.gitignore @@ -164,6 +164,6 @@ tutorials/tutorial_data_processing/* tutorials/*PixelPandemonium/* tutorials/dataset-PixelPandemonium/* *.html -_*.py +#_*.py dicom_select examples diff --git a/TPTBox/segmentation/VibeSeg/__init__.py b/TPTBox/segmentation/VibeSeg/__init__.py new file mode 100644 index 0000000..b4f8bcf --- /dev/null +++ b/TPTBox/segmentation/VibeSeg/__init__.py @@ -0,0 +1 @@ +from TPTBox.segmentation.VibeSeg.vibeseg import extract_vertebra_bodies_from_VibeSeg From 346e9db87bdb24cabf0ca64db16571da4e8e9087 Mon Sep 17 00:00:00 2001 From: Robert Graf Date: Sat, 29 Nov 2025 02:19:53 +0000 Subject: [PATCH 28/28] ruff: remove Optional --- TPTBox/registration/deepali/_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/TPTBox/registration/deepali/_utils.py b/TPTBox/registration/deepali/_utils.py index d81ff86..988c6fc 100644 --- a/TPTBox/registration/deepali/_utils.py +++ b/TPTBox/registration/deepali/_utils.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from contextlib import ContextDecorator from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Union import torch import torch.optim @@ -49,7 +49,7 @@ def get_post_transform( target_grid: Grid, source_grid: Grid, align=False, -) -> Optional[SpatialTransform]: +) -> SpatialTransform | None: r"""Get constant rigid transformation between image grid domains.""" if align is False or align is None: return None @@ -90,7 +90,7 @@ def load_transform(path: PathStr, grid: Grid) -> SpatialTransform: """ target_grid = grid - def convert_matrix(matrix: Tensor, grid: Optional[Grid] = None) -> Tensor: + def convert_matrix(matrix: Tensor, grid: Grid | None = None) -> Tensor: if grid is None: pre = target_grid.transform(Axes.CUBE_CORNERS, Axes.WORLD) post = target_grid.transform(Axes.WORLD, Axes.CUBE_CORNERS) @@ -166,7 +166,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.scheduler.step() -def overlap_mask(source_mask: Tensor | None, target_mask: Tensor | None) -> Optional[Tensor]: +def overlap_mask(source_mask: Tensor | None, target_mask: Tensor | None) -> Tensor | None: r"""Overlap mask at which to evaluate pairwise data term.""" if source_mask is None: return target_mask @@ -184,11 +184,12 @@ def make_foreground_mask(image: Image, foreground_lower_threshold, foreground_up return Image(mask, image.grid()) -def normalize_img(image: Image, normalize_strategy: Optional[Literal["auto", "CT", "MRI"]]): +def normalize_img(image: Image, normalize_strategy: Literal["auto", "CT", "MRI"] | None): if normalize_strategy is None: return image data = image.tensor() if normalize_strategy == "MRI": + data = data.float() max_v = torch.quantile(data[data > 0], q=0.95) min_v = 0 elif normalize_strategy == "CT": @@ -210,7 +211,7 @@ def normalize_img(image: Image, normalize_strategy: Optional[Literal["auto", "CT return Image(data, image.grid()) -def clamp_mask(image: Optional[Image]): +def clamp_mask(image: Image | None): if image is None: return image data = image.tensor()