diff --git a/.gitignore b/.gitignore index 1b7731d..f06f38f 100755 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ tutorials/*PixelPandemonium/* tutorials/dataset-PixelPandemonium/* *.html _*.py -dicom_select \ No newline at end of file +dicom_select +examples diff --git a/README.md b/README.md index 1096ac4..8abcb7c 100755 --- a/README.md +++ b/README.md @@ -87,23 +87,36 @@ nii.shape #shape ``` ### Stitching Python function and script for arbitrary image stitching. [See Details](TPTBox/stitching/) + +![Example of stitching](TPTBox/stitching/stitching.jpg) ### Spineps and Points of Interest (POI) integration + +![Example of two lumbar vertebrae. The left example is derived from 1 mm isotropic CT, the right from sagittal MRI with a resolution of 3.3 mm in the left–right direction. Top row: Subregion of the vertebra used for analysis. Middle row: Extreme points. Bottom row: Corpus edge and ligamentum flavum points.](TPTBox/images/poi_preview.png) For our spine segmentation pipeline, follow the installation of [SPINEPS](https://github.com/Hendrik-code/spineps). +Image Source: Rule-based Key-Point Extraction for MR-Guided Biomechanical Digital Twins of the Spine. + + SPINEPS will produce two masks: instance and semantic labels. With these we can compute our POIs. These are either center-of-mass points or surface points with biological meaning. See [Validation of a Patient-Specific Musculoskeletal Model for Lumbar Load Estimation Generated by an Automated Pipeline From Whole Body CT](https://pubmed.ncbi.nlm.nih.gov/35898642/) ```python -from TPTBox import NII, POI, Location, calc_poi_from_subreg_vert +from TPTBox import NII, POI, Location, POI_Global, calc_poi_from_subreg_vert from TPTBox.core.vert_constants import v_name2idx +from TPTBox.segmentation.spineps import run_spineps_single + +# This requires that SPINEPS is installed +output_paths = run_spineps_single( + "file-path-of_T2w.nii.gz", + model_semantic="t2w", + ignore_compatibility_issues=True, +) +out_spine = output_paths["out_spine"] +out_vert = output_paths["out_vert"] +semantic_nii = NII.load(out_spine, seg=True) +instance_nii = NII.load(out_vert, seg=True) -p = "/dataset-DATASET/derivatives/A/" -semantic_nii = NII.load(f"{p}sub-A_sequ-stitched_acq-sag_mod-T2w_seg-spine_msk.nii.gz", seg=True) -instance_nii = NII.load(f"{p}sub-A_sequ-stitched_acq-sag_mod-T2w_seg-vert_msk.nii.gz", seg=True) -poi_path = f"{p}sub-A_sequ-stitched_acq-sag_mod-T2w_seg-spine_ctd.json" -poi = POI.load(poi_path) poi = calc_poi_from_subreg_vert( instance_nii, semantic_nii, - # buffer_file=poi_path, subreg_id=[ Location.Vertebra_Full, Location.Arcus_Vertebrae, @@ -157,18 +170,34 @@ poi = calc_poi_from_subreg_vert( Location.Vertebra_Direction_Right, ], ) -# poi.save(poi_path) poi = poi.round(2) print("Vertebra T4 Vertebra Corpus Center of mass:", poi[v_name2idx["T4"], Location.Vertebra_Corpus]) -# rescale/reorante like nii +print("The id number of T4 Vertebra_Corpus is ", v_name2idx["T4"], Location.Vertebra_Corpus.value) + +# rescale/reorient a local POI like a NII poi_new = poi.reorient(("P", "I", "R")).rescale((1, 1, 1)) +# Local and global POIs can be resampled to a target grid with: poi_new = poi.resample_from_to(other_nii_or_poi) +# local to global POI +global_poi = poi.to_global(itk_coords=True) +# You can save global POIs in mrk.json format for import and editing in Slicer.
+global_poi.save_mrk("FILE.mrk.json", glyphScale=3.0)
+# Import as a Markup in Slicer; to make points editable, click the "lock" symbol under Markups - Control Points - Interaction
+
+# Save in our format:
+poi_path = "FILE.json"
+poi.save(poi_path)
+# Loading a local/global POI
+poi = POI.load(poi_path)
+poi = POI_Global.load(poi_path)
+
+
+
 ``` ### Snapshot2D Spine - +![Snapshot2D Spine example](TPTBox/images/snp2D_example.png) The snapshot function automatically generates sagittal, coronal, and axial cuts through the center of a segmentation. ```python @@ -188,17 +217,21 @@ create_snapshot(snp_path="snapshot.jpg", frames=[ct_frame, mr_frame]) ### Snapshot3D - +![Snapshot3D example](TPTBox/images/snp3D_example.jpg) Requires additional Python packages: vtk, fury, xvfbwrapper ```python -from TPTBox.mesh3D.snapshot3D import make_snapshot3D +from TPTBox.mesh3D.snapshot3D import make_snapshot3D, make_snapshot3D_parallel + # all segmentations; view gives the rotation of the image -make_snapshot3D("sub-101000_msk.nii.gz","snapshot3D.png",view=["A", "L", "P", "R"]) +make_snapshot3D("sub-101000_msk.nii.gz", "snapshot3D.png", view=["A", "L", "P", "R"]) # Select which segmentations are shown per panel. -make_snapshot3D("sub-101000_msk.nii.gz","snapshot3D_v2.png",view=["A"], ids_list=[[1,2],[3]]) +make_snapshot3D("sub-101000_msk.nii.gz", "snapshot3D_v2.png", view=["A"], ids_list=[[1, 2], [3]]) +# We provide an implementation to process multiple images at the same time. +make_snapshot3D_parallel(["a.nii.gz", "b.nii.gz"], ["snp_a.png", "snp_b.png"], view=["A"]) ``` + \ No newline at end of file diff --git a/TPTBox/core/bids_constants.py b/TPTBox/core/bids_constants.py index 328019e..6c9c677 100755 --- a/TPTBox/core/bids_constants.py +++ b/TPTBox/core/bids_constants.py @@ -152,6 +152,7 @@ "logit", "localizer", "difference", + "labels", ] # https://bids-specification.readthedocs.io/en/stable/appendices/entity-table.html formats_relaxed = [*formats, "t2", "t1", "t2c", "t1c", "cta", "mr", "snapshot", "t1dixon", "dwi"] @@ -175,6 +176,7 @@ file_types = [ "nii.gz", "json", + "mrk.json", "png", "jpg", "tsv", @@ -190,6 +192,7 @@ "xlsx", "bvec", "bval", + "html", ] # Description see: https://bids-specification.readthedocs.io/en/stable/99-appendices/09-entities.html diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index 678c6eb..64d98ca 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -3,6 +3,7 @@ import json import os +import random import sys import typing from collections.abc import Sequence @@ -358,10 +359,14 @@ def add_file_2_subject(self, bids: BIDS_FILE | Path, ds=None) -> None: ) if self.verbose else None self.subjects[subject].add(bids) - def enumerate_subjects(self, sort=False) -> list[tuple[str, Subject_Container]]: + def enumerate_subjects(self, sort=False, shuffle=False) -> list[tuple[str, Subject_Container]]: # TODO Enumerate should put out numbers...
if sort: return sorted(self.subjects.items()) + if shuffle: + s = list(self.subjects.items()) + random.shuffle(s) + return s return self.subjects.items() # type: ignore def iter_subjects(self, sort=False) -> list[tuple[str, Subject_Container]]: @@ -686,7 +691,7 @@ def get_changed_path( # noqa: C901 from_info=False, auto_add_run_id=False, additional_folder: str | None = None, - dataset_path: str | None = None, + dataset_path: str | Path | None = None, make_parent=False, no_sorting_mode: bool = False, non_strict_mode: bool = False, diff --git a/TPTBox/core/nii_poi_abstract.py b/TPTBox/core/nii_poi_abstract.py index 96678b0..15c377d 100755 --- a/TPTBox/core/nii_poi_abstract.py +++ b/TPTBox/core/nii_poi_abstract.py @@ -7,6 +7,7 @@ import nibabel as nib import nibabel.orientations as nio import numpy as np +from scipy.spatial.transform import Rotation from typing_extensions import Self from TPTBox.core.np_utils import np_count_nonzero @@ -113,6 +114,65 @@ def _extract_affine(self: Has_Grid, rm_key=(), **args): out.pop(k) return out + def change_affine( + self, + translation=None, + rotation_degrees=None, + scaling=None, + degrees=True, + inplace=False, + ): + """ + Apply a transformation (translation, rotation, scaling) to the affine matrix. + + Parameters: + translation: (n,) array-like in mm + rotation_degrees: (n,) array-like (pitch, yaw, roll) in degrees + scaling: (n,) array-like scaling factors along x, y, z + """ + warnings.warn("change_affine is untested", stacklevel=2) + n = self.affine.shape[0] + transform = np.eye(n) + + # Scaling + if scaling is not None: + assert len(scaling) == n - 1, f"Scaling must be a {n - 1}-element array-like." + S = np.diag([*list(scaling), 1]) + transform = S @ transform + + # Rotation + if rotation_degrees is not None: + assert len(rotation_degrees) == n - 1, f"Rotation must be a {n - 1}-element array-like." + rot = Rotation.from_euler("xyz", rotation_degrees, degrees=degrees).as_matrix() + R_mat = np.eye(n) + R_mat[: n - 1, : n - 1] = rot + transform = R_mat @ transform + + # Translation + if translation is not None: + T = np.eye(n) + T[: n - 1, n - 1] = translation + transform = T @ transform + if not inplace: + self = self.copy() # noqa: PLW0642 + # Update the affine + self.affine = transform @ self.affine + return self + + def change_affine_(self, translation=None, rotation_degrees=None, scaling=None, degrees=True): + return self.change_affine( + translation=translation, + rotation_degrees=rotation_degrees, + scaling=scaling, + degrees=degrees, + inplace=True, + ) + + def copy(self) -> Self: + raise NotImplementedError( + "The copy method must be implemented in the subclass. It should return a new instance of the same type with the same attributes." 
+ ) + def assert_affine( self, other: Self | Has_Grid | None = None, @@ -181,11 +241,11 @@ def assert_affine( found_errors.append(f"rotation mismatch {self.rotation}, {rotation}") if not rotation_match else None if zoom is not None and (not ignore_missing_values or self.zoom is not None): if self.zoom is None: - found_errors.append(f"zoom mismatch {self.zoom}, {zoom}") + found_errors.append(f"spacing mismatch {self.zoom}, {zoom}") else: zms_diff = (self.zoom[i] - zoom[i] for i in range(3)) zms_match = np.all([abs(a) <= error_tolerance for a in zms_diff]) - found_errors.append(f"zoom mismatch {self.zoom}, {zoom}") if not zms_match else None + found_errors.append(f"spacing mismatch {self.zoom}, {zoom}") if not zms_match else None if orientation is not None and (not ignore_missing_values or self.affine is not None): if self.orientation is None: found_errors.append(f"orientation mismatch {self.orientation}, {orientation}") @@ -256,7 +316,14 @@ def get_empty_POI(self, points: dict | None = None): from TPTBox import POI p = {} if points is None else points - return POI(p, orientation=self.orientation, zoom=self.zoom, shape=self.shape, rotation=self.rotation, origin=self.origin) + return POI( + p, + orientation=self.orientation, + zoom=self.zoom, + shape=self.shape, + rotation=self.rotation, + origin=self.origin, + ) def make_empty_POI(self, points: dict | None = None): from TPTBox import POI @@ -267,7 +334,15 @@ def make_empty_POI(self, points: dict | None = None): args["level_one_info"] = self.level_one_info args["level_two_info"] = self.level_two_info - return POI(p, orientation=self.orientation, zoom=self.zoom, shape=self.shape, rotation=self.rotation, origin=self.origin, **args) + return POI( + p, + orientation=self.orientation, + zoom=self.zoom, + shape=self.shape, + rotation=self.rotation, + origin=self.origin, + **args, + ) def make_empty_nii(self, seg=False, _arr=None): from TPTBox import NII @@ -332,6 +407,30 @@ def to_deepali_grid(self, align_corners: bool = True): grid = grid.align_corners_(align_corners) return grid + @classmethod + def from_deepali_grid(cls, grid): + try: + from deepali.core import Grid as dp_Grid + except Exception: + log.print_error() + log.on_fail("run 'pip install hf-deepali' to install deepali") + raise + grid_: dp_Grid = grid + size = grid_.size() + spacing = grid_.spacing().cpu().numpy() + origin = grid_.origin().cpu().numpy() + direction = grid_.direction().cpu().numpy() + # Convert to ITK LPS convention + origin[:2] *= -1 + direction[:2] *= -1 + # Replace small values and -0 by 0 + epsilon = sys.float_info.epsilon + origin[np.abs(origin) < epsilon] = 0 + direction[np.abs(direction) < epsilon] = 0 + grid = Grid(shape=size, origin=origin, spacing=spacing, rotation=direction) # type: ignore + + return grid + def get_num_dims(self): return len(self.shape) @@ -342,9 +441,14 @@ def __init__(self, **qargs) -> None: for k, v in qargs.items(): if k == "spacing": k = "zoom" # noqa: PLW2901 + if k == "direction": + k = "rotation" # noqa: PLW2901 if k == "rotation": v = np.array(v) # noqa: PLW2901 if len(v.shape) == 1: s = int(np.sqrt(v.shape[0])) v = v.reshape(s, s) # noqa: PLW2901 setattr(self, k, v) + + ort = nio.io_orientation(self.affine) + self.orientation = nio.ornt2axcodes(ort) # type: ignore diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 3481259..d924d7b 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import traceback import warnings import zlib 
@@ -41,6 +42,7 @@ np_map_labels_based_on_majority_label_mask_overlap, np_point_coordinates, np_smooth_gaussian_labelwise, + np_translate_arr, np_unique, np_unique_withoutzero, np_volume, @@ -80,13 +82,86 @@ def formatwarning_tb(*args, **kwargs): return s -_dtype_max = {"int8": 128, "uint8": 256, "int16": 32768, "uint16": 65536} +_dtype_max = { + "int8": 128, + "uint8": 256, + "int16": 32768, + "uint16": 65536, + "int32": 2147483647, + "uint32": 4294967295, +} # uint32 is not supported by NIfTI +_dtype_u = {"uint8", "uint16"} +_dtype_non_u = {"int8", "int16"} + + +def _check_if_nifty_is_lying_about_its_dtype(self: NII): + change_dtype = False + arr = self.nii.dataobj + dtype = self._nii.dataobj.dtype # type: ignore + dtype_s = str(self._nii.dataobj.dtype) + mi = np.min(arr) + ma = np.max(arr) + has_neg = mi < 0 + max_v = _dtype_max.get(str(dtype), 0) + positive = str(dtype) in _dtype_u + + if has_neg and positive: + warnings.warn( + f"Loaded NIfTI: incorrect dtype detected: {dtype}, but has negative values; min={mi}", + stacklevel=3, + ) + change_dtype = True + positive = False + if not has_neg and self.seg: + positive = True + if str(dtype) in _dtype_non_u: + change_dtype = True + if ma > max_v and "float" not in dtype_s: + change_dtype = True + warnings.warn( + f"Loaded NIfTI: incorrect dtype detected: {dtype}, values exceed the dtype max {max_v}; max={ma}", + stacklevel=3, + ) + + out_dtype = dtype + if dtype == np.float16: + warnings.warn( + f"Loaded NIfTI: incorrect dtype detected: {dtype} is not supported", + stacklevel=3, + ) + out_dtype = np.float32 + if "float" in dtype_s and not change_dtype: + pass + elif positive and change_dtype: + if ma < 256: + out_dtype = np.uint8 + elif ma < 65536: + out_dtype = np.uint16 + else: + out_dtype = np.int32 + elif change_dtype: + ma_abs = np.max(np.abs(arr)) + if ma_abs < 128: + out_dtype = np.int8 + elif ma_abs < 32768: + out_dtype = np.int16 + else: + out_dtype = np.int32 + # print("check", out_dtype, change_dtype, positive, has_neg) + return out_dtype + warnings.formatwarning = formatwarning_tb N = TypeVar("N", bound="NII") Image_Reference = Union[bids_files.BIDS_FILE, Nifti1Image, Path, str, N] -Interpolateable_Image_Reference = Union[bids_files.BIDS_FILE, tuple[Nifti1Image, bool], tuple[Path, bool], tuple[str, bool], N] +Interpolateable_Image_Reference = Union[ + bids_files.BIDS_FILE, + tuple[Nifti1Image, bool], + tuple[Path, bool], + tuple[str, bool], + N, +] Proxy = tuple[tuple[int, int, int], np.ndarray] suppress_dtype_change_printout_in_set_array = False @@ -155,11 +230,17 @@ def __init__(self, nii: Nifti1Image|_unpacked_nii, seg=False,c_val=None, desc:st self.__min = None self.info = info if info is not None else {} self.set_description(desc) + if seg: + self._unpack() + if np.issubdtype(self.dtype, np.floating): + self.set_dtype_("smallest_uint") @classmethod def load(cls, path: Image_Reference, seg, c_val=None)-> Self: nii= to_nii(path,seg) + if seg: + nii = nii.set_dtype("smallest_uint") nii.c_val = c_val return nii @@ -190,7 +271,9 @@ def load_nrrd(cls, path: str | Path, seg: bool): raise ImportError("The `pynrrd` package is required but not installed.
Install it with `pip install pynrrd`.") from None _nrrd = nrrd.read(path) data = _nrrd[0] + header = dict(_nrrd[1]) + #print(data.shape, header) #print(header) # Example print out: OrderedDict([ # ('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), @@ -211,7 +294,13 @@ #space_directions = space_directions[~np.isnan(space_directions).any(axis=1)] #Filter NAN n = header['dimension'] #print(data.shape) - + if space_directions.shape != (n, n): + space_directions = space_directions[~np.isnan(space_directions).all(axis=1)] + m = len(space_directions[0]) + if m != n: + n=m + data = data.sum(axis=0) + space_directions = space_directions.T if space_directions.shape != (n, n): raise ValueError(f"Expected 'space directions' to be a nxn matrix. n = {n} is not {space_directions.shape}",space_directions) if space_origin.shape != (n,): @@ -235,6 +324,8 @@ except KeyError as e: raise KeyError(f"Missing expected header field: {e}") from None + if len(data.shape) != n: + raise ValueError(f"{len(data.shape)=} different from n = ", n) ref_orientation = header.get("ref_orientation") for i in ["ref_orientation","dimension","space directions","space origin""space","type","endian"]: header.pop(i, None) @@ -281,28 +372,11 @@ def _unpack(self): try: if self.__unpacked: return - if self.seg: - m = np.max(self.nii.dataobj) - if m<256: - dtype = np.uint8 - elif m<65536: - dtype = np.uint16 - else: - dtype = np.int32 - self._arr = np.asanyarray(self.nii.dataobj, dtype=self.nii.dataobj.dtype).astype(dtype).copy() # type: ignore - self._checked_dtype = True - elif not self._checked_dtype: - # if the maximum is lager than the dtype, we use float. + if not self._checked_dtype or self.seg: + dtype = _check_if_nifty_is_lying_about_its_dtype(self) + #print("unpack-nii",f"{self.seg=}",dtype) self._checked_dtype = True - dtype = str(self.dtype) - if dtype not in _dtype_max: - self._arr = np.asanyarray(self.nii.dataobj, dtype=self.nii.dataobj.dtype).copy() #type: ignore - else: - m = np.max(self.nii.dataobj) - if m > _dtype_max[dtype]: - self._arr = self.nii.get_fdata() - else: - self._arr = np.asanyarray(self.nii.dataobj, dtype=self.nii.dataobj.dtype).copy() #type: ignore + self._arr = np.asanyarray(self.nii.dataobj, dtype=dtype).copy() else: self._arr = np.asanyarray(self.nii.dataobj, dtype=self.nii.dataobj.dtype).copy() #type: ignore @@ -355,7 +429,10 @@ def nii(self,nii:Nifti1Image|_unpacked_nii): # is there a dimension with size 1?
arr = arr.squeeze() # TODO try to get back to a saveable state, if this did not work - + if arr.dtype == np.uint64:#throws error + arr = arr.astype(np.uint32) + if arr.dtype == np.int64:#throws error + arr = arr.astype(np.int32) self._arr = arr self._aff = aff self._checked_dtype = True @@ -406,33 +483,40 @@ def affine(self,affine:np.ndarray): def orientation(self) -> AX_CODES: ort = nio.io_orientation(self.affine) return nio.ornt2axcodes(ort) # type: ignore - + @property + def dims(self)->int: + self._unpack() + return self.affine.shape[0]-1 @property def zoom(self) -> ZOOMS: - rotation_zoom = self.affine[:3, :3] + n = self.dims + rotation_zoom = self.affine[:n, :n] zoom = np.sqrt(np.sum(rotation_zoom * rotation_zoom, axis=0)) if self.__divergent else self.header.get_zooms() z = tuple(np.round(zoom,7)) - if len(z) == 4: - z = z[:3] - assert len(z) == 3,z + if len(z) >= n: + z = z[:n] + #assert len(z) == 3,z return z # type: ignore @property def origin(self) -> tuple[float, float, float]: - z = tuple(np.round(self.affine[:3,3],7)) + n = self.dims + z = tuple(np.round(self.affine[:n,n],7)) assert len(z) == 3 return z # type: ignore @origin.setter def origin(self,x:tuple[float, float, float]): + n = self.dims self._unpack() self.__divergent = True affine = self._aff - affine[:3,3] = np.array(x) # type: ignore + affine[:n,n] = np.array(x) # type: ignore self._aff = affine @property def rotation(self)->np.ndarray: - rotation_zoom = self.affine[:3, :3] + n = self.dims + rotation_zoom = self.affine[:n, :n] zoom = np.array(self.zoom) rotation = rotation_zoom / zoom return rotation @@ -540,10 +624,17 @@ def set_array(self,arr:np.ndarray|Self, inplace=False,verbose:logging=False,seg= def set_array_(self,arr:np.ndarray,verbose:logging=True): return self.set_array(arr,inplace=True,verbose=verbose) - def set_dtype(self,dtype:type|Literal['smallest_int'] = np.float32,order:Literal["C","F","A","K"] ='K',casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe", inplace=False): + def set_dtype(self,dtype:type|Literal['smallest_int','smallest_uint'] = np.float32,order:Literal["C","F","A","K"] ='K',casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe", inplace=False): sel = self if inplace else self.copy() - - if dtype == "smallest_int": + if dtype == "smallest_uint": + arr = self.get_array() + if arr.max()<256: + dtype = np.uint8 + elif arr.max()<65536: + dtype = np.uint16 + else: + dtype = np.int32 + elif dtype == "smallest_int": arr = self.get_array() if arr.max()<128: dtype = np.int8 @@ -551,16 +642,20 @@ def set_dtype(self,dtype:type|Literal['smallest_int'] = np.float32,order:Literal dtype = np.int16 else: dtype = np.int32 - - sel.nii.set_data_dtype(dtype) - if sel.nii.get_data_dtype() != self.dtype: #type: ignore - sel.nii = Nifti1Image(self.get_array().astype(dtype,casting=casting,order=order),self.affine,self.header) + if self.__unpacked: + self._unpack() + sel._arr = sel._arr.astype(dtype) + sel.header.set_data_dtype(dtype) + else: + sel.nii.set_data_dtype(dtype) + if sel.nii.get_data_dtype() != self.dtype: #type: ignore + sel.nii = Nifti1Image(self.get_array().astype(dtype,casting=casting,order=order),self.affine,self.header) return sel - def set_dtype_(self,dtype:type|Literal['smallest_int'] = np.float32,order:Literal["C","F","A","K"] ='K',casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe"): + def set_dtype_(self,dtype:type|Literal['smallest_uint','smallest_int'] = np.float32,order:Literal["C","F","A","K"]
='K',casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe"): return self.set_dtype(dtype=dtype,order=order,casting=casting, inplace=True) - def astype(self,dtype,order:Literal["C","F","A","K"] ='K', casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe",subok=True, copy=True): + def astype(self,dtype,order:Literal["C","F","A","K"] ='K', casting:Literal["no","equiv","safe","same_kind","unsafe"] = "unsafe",subok=True, copy=True)->Self: ''' numpy wrapper ''' if subok: c = self.set_dtype(dtype,order=order,casting=casting, inplace=copy) @@ -713,7 +808,7 @@ def apply_crop_slice_(self,*args,**qargs): warnings.warn("apply_crop_slice_ is deprecated, use apply_crop_ instead",stacklevel=5) #TODO remove in version 1.0 return self.apply_crop_(*args,**qargs) - def apply_crop(self,ex_slice:tuple[slice,slice,slice]|Sequence[slice] , inplace=False): + def apply_crop(self,ex_slice:tuple[slice,slice,slice]|Sequence[slice]|None , inplace=False): """ The apply_crop_slice function applies a given slice to reduce the Nifti image volume. If a list of slices is provided, it computes the minimum volume of all slices and applies it. @@ -745,10 +840,14 @@ def pad_to(self,target_shape:list[int]|tuple[int,int,int] | Self, mode:MODES="co def apply_pad(self,padd:Sequence[tuple[int|None,int]],mode:MODES="constant",inplace = False,verbose:logging=True): #TODO add other modes #TODO add testcases and options for modes - transform = np.eye(4, dtype=int) + transform = np.eye(self.dims+1, dtype=int) + assert len(padd) == self.dims for i, (before,_) in enumerate(padd): #transform[i, i] = pad_slice.step if pad_slice.step is not None else 1 transform[i, -1] = -before if before is not None else 0 + + while len(padd) < len(self.shape): + padd = (*tuple(padd), (0, 0)) affine = self.affine.dot(transform) args = {} if mode == "constant": @@ -781,7 +880,7 @@ def reorient_same_as(self, img_as: Nifti1Image | Self, verbose:logging=False, in return self.reorient(axcodes_to=axcodes_to, verbose=verbose, inplace=inplace) def reorient_same_as_(self, img_as: Nifti1Image | Self, verbose:logging=False) -> Self: return self.reorient_same_as(img_as=img_as,verbose=verbose,inplace=True) - def rescale(self, voxel_spacing=(1, 1, 1), c_val:float|None=None, verbose:logging=False, inplace=False,mode:MODES='nearest',order: int |None = None,align_corners:bool=False): + def rescale(self, voxel_spacing:float|tuple[float,...]=(1, 1, 1), c_val:float|None=None, verbose:logging=False, inplace=False,mode:MODES='nearest',order: int |None = None,align_corners:bool=False,atol=0.001): """ Rescales the NIfTI image to a new voxel spacing. @@ -796,12 +895,16 @@ def rescale(self, voxel_spacing=(1, 1, 1), c_val:float|None=None, verbose:loggin mode (str, optional): One of the supported modes by scipy.ndimage.interpolation (e.g., "constant", "nearest", "reflect", "wrap"). See the documentation for more details. Defaults to "constant". align_corners (bool|default): If True, or not set and seg==True, align corners for scaling. This prevents the segmentation mask from shifting in a direction. + atol: absolute tolerance; resampling is skipped if the current voxel spacing is already within atol of the target Returns: NII: A new NII object with the resampled image data.
""" if isinstance(voxel_spacing, (int,float)): - voxel_spacing =(voxel_spacing,voxel_spacing,voxel_spacing) - if voxel_spacing in ((-1, -1, -1), self.zoom): + voxel_spacing =(voxel_spacing for _ in range(min(3,self.affine.shape[0]-1))) + n = self.dims + while n> len(voxel_spacing): + voxel_spacing = (*voxel_spacing, -1) + if all(a in (-1, b) for a,b in zip(voxel_spacing, self.zoom)): log.print(f"Image already resampled to voxel size {self.zoom}",verbose=verbose) return self.copy() if inplace else self @@ -812,15 +915,18 @@ def rescale(self, voxel_spacing=(1, 1, 1), c_val:float|None=None, verbose:loggin zms = self.zoom if order is None: order = 0 if self.seg else 3 + #print(aff.shape,shp,zms,voxel_spacing) voxel_spacing = tuple([v if v != -1 else z for v,z in zip_strict(voxel_spacing,zms)]) - if voxel_spacing == self.zoom: + if np.isclose(voxel_spacing, self.zoom,atol=atol).all(): log.print(f"Image already resampled to voxel size {self.zoom}",verbose=verbose) return self.copy() if inplace else self # Calculate new shape new_shp = tuple(np.rint([shp[i] * zms[i] / voxel_spacing[i] for i in range(len(voxel_spacing))]).astype(int)) - new_aff = nib.affines.rescale_affine(aff, shp, voxel_spacing, new_shp) # type: ignore - new_aff[:3, 3] = nib.affines.apply_affine(aff, [0, 0, 0])# type: ignore + if len(new_shp) < len(shp): + new_shp = new_shp + shp[len(new_shp):] + new_aff = _rescale_affine(aff, shp, voxel_spacing, new_shp) # type: ignore + new_aff[:n, n] = nib.affines.apply_affine(aff, [0 for _ in range(n)])# type: ignore new_img = _resample_from_to(self, (new_shp, new_aff,voxel_spacing), order=order, mode=mode,align_corners=align_corners) log.print(f"Image resampled from {zms} to voxel size {voxel_spacing}",verbose=verbose) if inplace: @@ -845,11 +951,11 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN """ '''''' c_val = self.get_c_val(c_val) if isinstance(to_vox_map,Has_Grid): - mapping = to_vox_map + mapping = to_vox_map.to_gird() else: mapping = to_vox_map if isinstance(to_vox_map, tuple) else to_nii_optional(to_vox_map, seg=self.seg, default=to_vox_map) - if isinstance(mapping,Has_Grid) and mapping.assert_affine(self,raise_error=False,origin_tolerance=0.00000001,error_tolerance=0.00000001,shape_tolerance=0): - log.print(f"resample_from_to skipped: {self}",verbose=verbose) + if isinstance(mapping,Has_Grid) and mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0): + log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose) return self if inplace else self.copy() assert mapping is not None log.print(f"resample_from_to: {self} to {mapping}",verbose=verbose) @@ -1017,7 +1123,7 @@ def smooth_gaussian_labelwise( radius: int = 6, truncate: int = 4, boundary_mode: str = "nearest", - dilate_prior: int = 1, + dilate_prior: int = 0, dilate_connectivity: int = 1, smooth_background: bool = True, inplace: bool = False, @@ -1262,7 +1368,7 @@ def get_connected_components(self, labels: int |list[int]=1, connectivity: int = out, _ = np_connected_components(self.get_seg_array(), label_ref=labels, connectivity=connectivity, include_zero=include_zero) return self.set_array(out,inplace=inplace) - def get_connected_components_per_label(self, labels: int |list[int]=1, connectivity: int = 3, include_zero: bool=False) -> dict[int, Self]: # noqa: ARG002 + def get_connected_components_per_label(self, labels: int |list[int], connectivity: int = 3, include_zero: bool=False) -> dict[int, Self]: # noqa: 
ARG002 assert self.seg, "This only works on segmentations" out = np_connected_components_per_label(self.get_seg_array(), label_ref=labels, connectivity=connectivity, include_zero=include_zero) cc = {i: self.set_array(k) for i,k in out.items()} return cc @@ -1690,17 +1796,16 @@ def clone(self): def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None): if make_parents: Path(file).parent.mkdir(exist_ok=True,parents=True) - arr = self.get_array() + + arr = self.get_array() if not self.seg else self.get_seg_array() + if np.issubdtype(arr.dtype, np.floating) and self.seg: + self.set_dtype_("smallest_uint") + arr = self.get_array() if not self.seg else self.get_seg_array() + + out = Nifti1Image(arr, self.affine,self.header)#,dtype=arr.dtype) if dtype is not None: out.set_data_dtype(dtype) - elif self.seg: - if arr.max()<256: - out.set_data_dtype(np.uint8) - elif arr.max()<65536: - out.set_data_dtype(np.uint16) - else: - out.set_data_dtype(np.int32) if out.header["qform_code"] == 0: #NIFTI_XFORM_UNKNOWN Will cause an error for some rounding of the affine in ITKSnap ... # 1 means Scanner coordinate system # 2 means align (to something) coordinate system @@ -1843,15 +1948,16 @@ def is_intersecting_vertical(self, b: Self, min_overlap=40) -> bool: return True return min_v < x2[2] < max_v - def get_intersecting_volume(self, b: Self) -> bool: + def get_intersecting_volume(self, b: Self) -> float: ''' computes intersecting volume ''' - b = b.copy() # type: ignore + b = to_nii(b).copy() # type: ignore b.nii = Nifti1Image(b.get_array()*0+1,affine=b.affine) b.seg = True b.set_dtype_(np.uint8) - b = b.resample_from_to(self,c_val=0,verbose=False) # type: ignore + b.c_val = 0 + b = b.resample_from_to(self,c_val=0,verbose=False,mode="constant") # type: ignore return b.get_array().sum() def extract_background(self,inplace=False): @@ -1882,7 +1988,7 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum], keep_label=F return self.set_array(seg_arr,inplace=inplace) def extract_label_(self,label:int|Location|Sequence[int]|Sequence[Location], keep_label=False): return self.extract_label(label,keep_label,inplace=True) - def remove_labels(self,label:int|Location|Sequence[int]|Sequence[Location], inplace=False, verbose:logging=True): + def remove_labels(self,label:int|Location|Sequence[int]|Sequence[Location], inplace=False, verbose:logging=True, removed_to_label=0): '''If this NII is a segmentation you can remove selected labels.''' assert label != 0, 'Zero label does not make sense.
This is the background' seg_arr = self.get_seg_array() @@ -1891,9 +1997,9 @@ def remove_labels(self,label:int|Location|Sequence[int]|Sequence[Location], inpl for l in label: if isinstance(l, list): for g in l: - seg_arr[seg_arr == g] = 0 + seg_arr[seg_arr == g] = removed_to_label else: - seg_arr[seg_arr == l] = 0 + seg_arr[seg_arr == l] = removed_to_label return self.set_array(seg_arr,inplace=inplace, verbose=verbose) def remove_labels_(self,label:int|Location|Sequence[int]|Sequence[Location], verbose:logging=True): return self.remove_labels(label,inplace=True,verbose=verbose) @@ -1909,10 +2015,53 @@ def unique(self,verbose:logging=False): out = np_unique_withoutzero(self.get_seg_array()) log.print(out,verbose=verbose) return out + def voxel_volume(self): + product = math.prod(self.spacing) + return product - def volumes(self, include_zero: bool = False) -> dict[int, int]: + def volumes(self, include_zero: bool = False, in_mm3=False) -> dict[int, float]|dict[int, int]: '''Returns a dict stating how many voxels (or mm³ if in_mm3=True) are present for each label''' - return np_volume(self.get_seg_array(), include_zero=include_zero) + dic = np_volume(self.get_seg_array(), include_zero=include_zero) + if in_mm3: + voxel_size = self.voxel_volume() + dic = {k:v*voxel_size for k,v in dic.items()} + return dic + def translate_arr( + self, + translation_vector: tuple[int, int, int] | dict[str,int]| dict[DIRECTIONS,int], + inplace: bool = False, + verbose: bool = True + ): + """ + Translates the NIfTI image array by a given vector. Translation can be specified as a tuple + or as a dict using anatomical directions ('S', 'I', 'R', 'L', 'P', 'A'). + + Args: + translation_vector (tuple or dict): Vector to translate by. Can be a tuple of ints + or a dict with keys from 'SIRLPA' and int values. + inplace (bool): Whether to modify the array in place. + verbose (bool): Whether to print log messages. + + Returns: + NII: Translated image.
+ """ + log.print("Translating array", end="\r", verbose=verbose) + arr = self.get_array() + + if isinstance(translation_vector, dict): + # Start with zero translation for all axes + vector = [0] * arr.ndim + for direction, amount in translation_vector.items(): + axis = self.get_axis(direction=direction) # type: ignore + sign = +1 if direction in self.orientation else -1 + vector[axis] += sign * amount + v: tuple[int, int, int] = tuple(vector) # type: ignore + else: + v = translation_vector + arr_translated = np_translate_arr(arr, v) + log.print(f"Translated by {v}; ", verbose=verbose) + return self.set_array(arr_translated, inplace=inplace) def center_of_masses(self) -> dict[int, COORDINATE]: '''Returns a dict stating the center of mass for each present label (not including zero!)''' @@ -1935,7 +2084,8 @@ def to_nii_optional(img_bids: Image_Reference|None, seg=False, default=None) -> def to_nii(img_bids: Image_Reference, seg=False) -> NII: if isinstance(img_bids, Path): img_bids = str(img_bids) - if isinstance(img_bids, NII): + # ugly workaround due to module import issues with nii files + if isinstance(img_bids, NII) or img_bids.__class__.__name__ == "NII": return img_bids.copy() elif isinstance(img_bids, bids_files.BIDS_FILE): return img_bids.open_nii() @@ -1969,3 +2119,18 @@ def to_nii_interpolateable(i_img:Interpolateable_Image_Reference) -> NII: return i_img.open_nii() else: raise TypeError("to_nii_interpolateable",i_img) + +def _rescale_affine(affine, shape, zooms, new_shape=None): + shape = np.asarray(shape) + new_shape = np.array(new_shape if new_shape is not None else shape) + s = nib.affines.voxel_sizes(affine) + n = len(zooms) + if len(shape)>= n: + shape = shape[:n] + new_shape = new_shape[:n] + rzs_out = affine[:n, :n] * zooms / s + + # Using xyz = A @ ijk, determine translation + centroid = nib.affines.apply_affine(affine, (shape - 1) // 2) + t_out = centroid - rzs_out @ ((new_shape - 1) // 2) + return nib.affines.from_matvec(rzs_out, t_out) diff --git a/TPTBox/core/nii_wrapper_math.py b/TPTBox/core/nii_wrapper_math.py index 374afe7..1bb36eb 100755 --- a/TPTBox/core/nii_wrapper_math.py +++ b/TPTBox/core/nii_wrapper_math.py @@ -9,6 +9,8 @@ from skimage.metrics import structural_similarity as ssim from typing_extensions import Self +from TPTBox.core.np_utils import np_dice + from .nii_poi_abstract import Has_Grid # fmt: off @@ -23,6 +25,8 @@ def get_array(self) -> np.ndarray: ... def set_array(self,arr:np.ndarray,inplace=False,verbose=True)->Self: ... + def get_seg_array(self) -> np.ndarray: + ... @property def shape(self) -> tuple[int, int, int]: ... @@ -38,12 +42,16 @@ def affine(self) -> np.ndarray: ... def get_c_val(self)->int: ... + def unique(self)->list[int]: + ... 
C = Union[NII, Number, np.ndarray] else: class NII_Proxy: pass C = Union[Self,Number,np.ndarray] + class NII_Math(NII_Proxy,Has_Grid): + __hash__ = None # type: ignore # explicitly mark as unhashable def _binary_opt(self, other:C, opt,inplace = False)-> Self: if isinstance(other,NII_Math): other = other.get_array() @@ -208,7 +216,18 @@ def psnr(self,nii: NII_Proxy,min_v=0): img_2[img_2<=0] = 0 ssim_value = psnr(img_1, img_2,data_range=img_1.max() - img_1.min()) return ssim_value - + def dice(self,nii: NII_Proxy,bar=True)->dict[int,float]: + out:dict[int,float] = {} + gt = self.get_seg_array() + pred = nii.get_seg_array() + s = set(self.unique()+nii.unique()) + if bar: + from tqdm import tqdm + s = tqdm(s,desc="dice") + for lbl in s: + out[lbl] = np_dice(pred,gt,label=lbl) + #print(out[lbl]) + return out def betti_numbers(self: NII,verbose=False) -> dict[int, tuple[int, int, int]]: # type: ignore """ diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 9816f98..090d4fa 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -17,7 +17,12 @@ from cc3d import voxel_connectivity_graph as _voxel_connectivity_graph from fill_voids import fill as _fill from numpy.typing import NDArray -from scipy.ndimage import binary_erosion, center_of_mass, gaussian_filter, generate_binary_structure +from scipy.ndimage import ( + binary_erosion, + center_of_mass, + gaussian_filter, + generate_binary_structure, +) from skimage.measure import euler_number as _euler_number from skimage.measure import label as _label @@ -51,7 +56,11 @@ def np_extract_label( return arr == label if to_label == 0: - warnings.warn("np_extract_label: to_label is zero, this can have unforeseen consequences!", UserWarning, stacklevel=4) + warnings.warn( + "np_extract_label: to_label is zero, this can have unforeseen consequences!", + UserWarning, + stacklevel=4, + ) if not inplace: arr = arr.copy() @@ -90,7 +99,7 @@ def cc3dstatistics(arr: UINTARRAY, use_crop: bool = True) -> dict: Raises: AssertionError: If the input array is not of an unsigned integer or boolean dtype. 
""" - assert np.issubdtype(arr.dtype, np.unsignedinteger) or np.issubdtype(arr.dtype, np.bool_), ( + assert np.issubdtype(arr.dtype, np.unsignedinteger) or np.issubdtype(arr.dtype, np.int32) or np.issubdtype(arr.dtype, np.bool_), ( f"cc3dstatistics expects uint type, got {arr.dtype}" ) try: @@ -180,7 +189,7 @@ def np_unique_withoutzero(arr: UINTARRAY) -> list[int]: return [i for i in np_unique(arr) if i != 0] -def np_center_of_mass(arr: UINTARRAY) -> dict[int, np.ndarray]: +def np_center_of_mass(arr: UINTARRAY) -> dict[int, COORDINATE]: """Calculates center of mass, mapping label in array to a coordinate (float) (exluding zero) Args: @@ -288,16 +297,18 @@ def np_dice(seg: np.ndarray, gt: np.ndarray, binary_compare: bool = False, label float: dice value """ assert seg.shape == gt.shape, f"shape mismatch, got {seg.shape}, and {gt.shape}" - if binary_compare: - seg = seg.copy() - seg[seg != 0] = 1 - gt = gt.copy() - gt[gt != 0] = 1 - label = 1 with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"invalid value encountered in double_scalars") - dice = float(np.sum(seg[gt == label]) * 2.0 / (np.sum(seg) + np.sum(gt))) + if binary_compare: + seg_l = seg != 0 + gt_l = gt != 0 + else: + seg_l = seg == label # predicted mask for this label + gt_l = gt == label # ground-truth mask for this label + intersect = np.logical_and(seg_l, gt_l).sum() + denom = seg_l.sum() + gt_l.sum() + dice = (2.0 * intersect) / (denom) if np.isnan(dice): return 1.0 return dice @@ -792,11 +803,15 @@ def np_filter_connected_components( for preserve_label in preserve: cc_out[labels_out == preserve_label] = i i += 1 - if removed_to_label != 0: - arr[np.logical_and(labels_out != 0, arr == 0)] = removed_to_label + if return_original_labels: arr *= cc_out > 0 # to get original labels + if removed_to_label != 0: + arr[np.logical_and(labels_out != 0, arr == 0)] = removed_to_label return arr + if removed_to_label != 0: + arr[np.logical_and(labels_out != 0, arr == 0)] = removed_to_label + return cc_out @@ -956,8 +971,20 @@ def np_smooth_gaussian_labelwise( dilate_connectivity: int = 3, smooth_background: bool = True, ) -> UINTARRAY: - """Smoothes labels in a segmentation mask array - + """Smoothes selected labels in a segmentation mask using Gaussian filtering, + while keeping other labels unaffected. + + Internal Description: + 1. Ensures label(s) to be smoothed are present in the segmentation. + 2. Optionally dilates specified labels prior to smoothing (if `dilate_prior > 0`). + 3. Iterates over each label: + - Creates a binary mask for that label. + - Applies Gaussian smoothing only if the label is in `label_to_smooth`. + - Optionally applies a weight from `label_weights`. + 4. Adds background as a separate smoothed or fixed mask depending on `smooth_background`. + 5. Stacks all label probability-like maps and computes a new segmentation by taking the + `argmax` over the stacked array, i.e., the label with the highest value wins per voxel. + 6. Replaces the indices in the argmax map with the original label values to preserve semantics. 
Args: arr (UINTARRAY): Input Segmentation Mask Array diff --git a/TPTBox/core/poi.py b/TPTBox/core/poi.py index 71cf2b3..b3eb7c5 100755 --- a/TPTBox/core/poi.py +++ b/TPTBox/core/poi.py @@ -40,7 +40,13 @@ ### CURRENT TYPE DEFINITIONS C = TypeVar("C", bound="POI") -POI_Reference = Union[bids_files.BIDS_FILE, Path, str, tuple[Image_Reference, Image_Reference, Sequence[int]], C] +POI_Reference = Union[ + bids_files.BIDS_FILE, + Path, + str, + tuple[Image_Reference, Image_Reference, Sequence[int]], + C, +] @dataclass @@ -114,6 +120,15 @@ class POI(Abstract_POI, Has_Grid): _zoom: ZOOMS = field(init=False, default=(1, 1, 1), repr=False, compare=False) _vert_orientation_pir = {} # Elusive; will not be saved; will not be copied. For Buffering results # noqa: RUF012 + def _set_inplace(self, poi: Self): + self.orientation = poi.orientation + self.centroids = poi.centroids + self.zoom = poi.zoom + self.shape = poi.shape + self.origin = poi.origin + self.rotation = poi.rotation + return self + @property def is_global(self): return False @@ -155,6 +170,8 @@ def spacing(self, value): def clone(self, **qargs): return self.copy(**qargs) + __hash__ = None # type: ignore # explicitly mark as unhashable + def copy( self, centroids: POI_DICT | POI_Descriptor | None = None, @@ -557,6 +574,9 @@ def to_global(self, itk_coords=False): def resample_from_to(self, ref: Has_Grid): return self.to_global().to_other(ref) + def resample_from_to_(self, ref: Has_Grid): + return self._set_inplace(self.resample_from_to(ref)) + def save( self, out_path: Path | str, @@ -901,6 +921,7 @@ def calc_poi_from_subreg_vert( # use_vertebra_special_action=True, _vert_ids=None, _print_phases=False, + _orientation_version=0, ) -> POI: """ Calculates the POIs of a subregion within a vertebral mask. This function is spine-opinionated, the general implementation is "calc_poi_from_two_masks". @@ -943,7 +964,7 @@ if _vert_ids is None: _vert_ids = vert_msk.unique() - from TPTBox.core.poi_fun.vertebra_pois_non_centroids import add_prerequisites, compute_non_centroid_pois + from TPTBox.core.poi_fun.vertebra_pois_non_centroids import add_prerequisites, compute_non_centroid_pois # noqa: PLC0415 subreg_id = add_prerequisites(_int2loc(subreg_id if isinstance(subreg_id, Sequence) else [subreg_id])) # type: ignore @@ -1035,6 +1056,7 @@ subreg_msk, _vert_ids=_vert_ids, log=log, + _orientation_version=_orientation_version, ) extend_to.apply_crop_reverse(crop, org_shape, inplace=True) return extend_to @@ -1128,6 +1150,8 @@ def calc_centroids( second_stage: int | Abstract_lvl = 50, extend_to: POI | None = None, inplace: bool = False, + bar=False, + _crop=True, ) -> POI: """ Calculates the centroid coordinates of each region in the given mask image.
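The new `bar` and `_crop` options of `calc_centroids` (continued in the next hunk) add a progress bar and a per-label cropping fast path. A usage sketch, not part of the patch, assuming the segmentation is the only required argument and `vert_msk.nii.gz` is a placeholder path:

```python
from TPTBox import NII
from TPTBox.core.poi import calc_centroids

msk = NII.load("vert_msk.nii.gz", seg=True)  # placeholder segmentation
# bar=True wraps the label loop in tqdm; with the default _crop=True each
# center of mass is computed on a tight per-label crop instead of a full-size mask.
poi = calc_centroids(msk, bar=True)
```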
@@ -1169,15 +1193,31 @@ if not inplace: extend_to = extend_to.copy() ctd_list = extend_to.centroids - extend_to.assert_affine(msk_nii, shape_tolerance=0.5, origin_tolerance=0.5) - for i in msk_nii.unique(): - msk_temp = np.zeros(msk_data.shape, dtype=bool) - msk_temp[msk_data == i] = True - ctr_mass: Sequence[float] = center_of_mass(msk_temp) # type: ignore + extend_to.assert_affine(msk_nii, shape_tolerance=1, origin_tolerance=1) + u = msk_nii.unique() + if bar: + from tqdm import tqdm + + u = tqdm(u) + for i in u: + if _crop: + # TODO test implementation and remove old + m = msk_nii.extract_label(i) + crop = m.compute_crop() + m2: NII = m[crop] + ctr_mass: Sequence[float] = center_of_mass(m2.get_seg_array()) # type: ignore + out_coord = tuple(round(x + sl.start, decimals) for x, sl in zip(ctr_mass, crop)) + else: + # OLD + msk_temp = np.zeros(msk_data.shape, dtype=bool) + msk_temp[msk_data == i] = True + ctr_mass: Sequence[float] = center_of_mass(msk_temp) # type: ignore + out_coord = tuple(round(x, decimals) for x in ctr_mass) + if second_stage == -1: - ctd_list[first_stage, int(i)] = tuple(round(x, decimals) for x in ctr_mass) + ctd_list[first_stage, int(i)] = out_coord else: - ctd_list[int(i), second_stage] = tuple(round(x, decimals) for x in ctr_mass) + ctd_list[int(i), second_stage] = out_coord return POI(ctd_list, **msk_nii._extract_affine(), **args) diff --git a/TPTBox/core/poi_fun/poi_abstract.py b/TPTBox/core/poi_fun/poi_abstract.py index 7436033..0c9660d 100755 --- a/TPTBox/core/poi_fun/poi_abstract.py +++ b/TPTBox/core/poi_fun/poi_abstract.py @@ -39,6 +39,10 @@ DIMENSIONS = 3 +def _flatten(vert_label): + return [item for sublist in vert_label for item in (sublist if isinstance(sublist, list) else [sublist])] # type: ignore + + class _Abstract_POI_Definition: def __init__( self, @@ -107,6 +111,8 @@ self.definition = definition self._len: int | None = None + __hash__ = None # explicitly mark as unhashable + def __set_name__(self, owner, name): self._name = "_" + name @@ -225,6 +231,7 @@ def str_to_int(self, key: str, subregion: bool): raise def str_to_int_list(self, *keys: int | str, subregion=False): + keys = _flatten(keys) out: list[int] = [] for k in keys: if isinstance(k, str): @@ -582,6 +589,8 @@ def remove(self, *label: tuple[int, int], inplace=False): return obj def extract_subregion(self, *location: Abstract_lvl | int, inplace=False): + location = _flatten(location) + location_values = tuple(l if isinstance(l, int) else l.value for l in location) extracted_centroids = POI_Descriptor() for x1, x2, y in self.centroids.items(): @@ -604,7 +613,9 @@ def extract_vert(self, *vert_label: int, inplace=False): def extract_vert_(self, *vert_label: int): return self.extract_vert(*vert_label, inplace=True) - def extract_region(self, *vert_label: int, inplace=False): + def extract_region(self, *vert_label: int | list[int], inplace=False): + # flatten list + vert_label = _flatten(vert_label) vert_labels = tuple(vert_label) extracted_centroids = POI_Descriptor() for x1, x2, y in self.centroids.items(): diff --git a/TPTBox/core/poi_fun/poi_global.py b/TPTBox/core/poi_fun/poi_global.py index 5aa1fd2..016bf72 100755 --- a/TPTBox/core/poi_fun/poi_global.py +++ b/TPTBox/core/poi_fun/poi_global.py @@ -1,20 +1,19 @@ from __future__ import annotations -import json from copy import deepcopy from pathlib import Path +###### GLOBAL POI ##### from typing_extensions import Self from TPTBox.core import poi from TPTBox.core.nii_poi_abstract import Has_Grid +from
TPTBox.core.poi_fun import save_mkr from TPTBox.core.poi_fun.poi_abstract import Abstract_POI, POI_Descriptor from TPTBox.core.poi_fun.save_load import FORMAT_GLOBAL, load_poi, save_poi from TPTBox.core.vert_constants import Abstract_lvl, logging from TPTBox.logger.log_file import log -###### GLOBAL POI ##### - class POI_Global(Abstract_POI): """ @@ -111,6 +110,9 @@ def to_other_poi(self, ref: poi.POI | Self): def to_global(self): return self + def to_local(self, msk: Has_Grid): + return self.resample_from_to(msk) + def resample_from_to(self, msk: Has_Grid): return self.to_other(msk) @@ -176,78 +178,27 @@ def save( self, out_path, make_parents, additional_info, save_hint=save_hint, resample_reference=resample_reference, verbose=verbose ) - def save_mrk(self, filepath: str | Path, color=None, split_by_region=True, split_by_subregion=False): - """ - Save the POI data to a .mrk.json file in Slicer Markups format. - Automatically sets coordinate system based on itk_coords. - Includes level_one_info and level_two_info in the description. - Preserves metadata from `info` dictionary. - """ - if color is None: - color = self.info.get("color", [1.0, 0.0, 0.0]) - filepath = Path(filepath) - if not filepath.name.endswith(".mrk.json"): - filepath = filepath.parent / (filepath.stem + ".mrk.json") - coordinate_system = "LPS" if self.itk_coords else "RAS" - - # Create list of control points - from TPTBox import NII - from TPTBox.mesh3D.mesh_colors import get_color_by_label - - list_markups = {} - for region, subregion, coords in self.centroids.items(): - try: - name = self.level_two_info(subregion).name - except Exception: - name = subregion - try: - name2 = self.level_one_info(region).name - except Exception: - name2 = region - key = "P" - color2 = color - if split_by_region: - key += str(region) + "_" - color2 = get_color_by_label(region).rgb.tolist() - if split_by_subregion: - key += str(subregion) - color2 = get_color_by_label(region).rgb.tolist() - if key not in list_markups: - list_markups[key] = { - "type": "Fiducial", - "coordinateSystem": coordinate_system, - "locked": False, - "labelFormat": "%N-%d", - "controlPoints": [], - "display": { - "visibility": True, - "opacity": 1.0, - "color": color2.copy(), - "propertiesLabelVisibility": False, - }, - "description": "", # self.info, - } - - list_markups[key]["controlPoints"].append( - { - "id": f"{region}-{subregion}", - "label": f"{region}-{subregion}", - "description": name, - "associatedNodeID": name2, - "position": list(coords), - "orientation": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - "selected": False, - "locked": False, - "visibility": True, - "positionStatus": "defined", - } - ) - mrk_data = { - "markups": list(list_markups.values()), - "schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#", - # "coordinateSystem": coordinate_system, - } - - with open(filepath, "w") as f: - json.dump(mrk_data, f, indent=2) - log.on_save(f"Saved .mrk.json to {filepath}") + def save_mrk( + self: Self, + filepath: str | Path, + color=None, + split_by_region=False, + split_by_subregion=False, + add_points: bool = True, + add_lines: list[save_mkr.MKR_Lines] | None = None, + display: save_mkr.MKR_Display | dict = None, # type: ignore + pointLabelsVisibility=False, + glyphScale=5.0, + ): + save_mkr._save_mrk( + poi=self, + filepath=filepath, + color=color, + split_by_region=split_by_region, + split_by_subregion=split_by_subregion, + add_points=add_points, + 
add_lines=add_lines, display=display, pointLabelsVisibility=pointLabelsVisibility, glyphScale=glyphScale, ) diff --git a/TPTBox/core/poi_fun/save_load.py b/TPTBox/core/poi_fun/save_load.py index eeb75bd..276a88d 100644 --- a/TPTBox/core/poi_fun/save_load.py +++ b/TPTBox/core/poi_fun/save_load.py @@ -453,7 +453,7 @@ def _get_poi_idx_from_text(idx: str, label: str, centroids): def _load_mkr_POI(dict_mkr: dict): centroids = POI_Descriptor() - if "@schema" not in dict_mkr and "markups-schema-v1.0.3" not in dict_mkr["@schema"]: + if "@schema" not in dict_mkr or "markups-schema-v1.0.3" not in dict_mkr["@schema"]: log.on_warning( "this file is possibly incompatible. Tested only with markups-schema-v1.0.3 and not", dict_mkr.get("@schema", "No Schema") ) @@ -470,7 +470,7 @@ log.on_warning("unknown coordinate system:", markup["coordinateSystem"]) continue if itk_coords is not None: - assert markup["coordinateSystem"] == "LPS" == itk_coords, "multiple rotations not supported" + assert (markup["coordinateSystem"] == "LPS") == itk_coords, "multiple rotations not supported" itk_coords = markup["coordinateSystem"] == "LPS" if markup.get("coordinateUnits", "mm") != "mm": log.on_warning("unknown coordinateUnits:", markup["coordinateUnits"]) diff --git a/TPTBox/core/poi_fun/save_mkr.py b/TPTBox/core/poi_fun/save_mkr.py new file mode 100644 index 0000000..d99357e --- /dev/null +++ b/TPTBox/core/poi_fun/save_mkr.py @@ -0,0 +1,338 @@ +import json +import random +from pathlib import Path + +###### GLOBAL POI ##### +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict + +import numpy as np +from typing_extensions import NotRequired + +from TPTBox.logger.log_file import log +from TPTBox.mesh3D.mesh_colors import RGB_Color, get_color_by_label + +if TYPE_CHECKING: + from TPTBox import POI_Global +CoordinateSystem = Literal["LPS", "RAS"] +ControlPointStatus = Literal["undefined", "preview", "defined"] +MarkupType = Literal["Fiducial", "Line", "Angle", "Curve", "ClosedCurve", "Plane", "ROI", "MeasurementVolume"] +VolumeUnit = Literal["mm3", "cm3"] + + +class MKR_Display(TypedDict, total=False): + visibility: NotRequired[bool] + opacity: NotRequired[float] + color: NotRequired[list[float]] + selectedColor: NotRequired[list[float]] + activeColor: NotRequired[list[float]] + propertiesLabelVisibility: NotRequired[bool] + pointLabelsVisibility: NotRequired[bool] + textScale: NotRequired[float] + glyphType: NotRequired[str] + glyphScale: NotRequired[float] + glyphSize: NotRequired[float] + useGlyphScale: NotRequired[bool] + sliceProjection: NotRequired[bool] + sliceProjectionUseFiducialColor: NotRequired[bool] + sliceProjectionOutlinedBehindSlicePlane: NotRequired[bool] + sliceProjectionColor: NotRequired[list[float]] + sliceProjectionOpacity: NotRequired[float] + lineThickness: NotRequired[float] + lineColorFadingStart: NotRequired[float] + lineColorFadingEnd: NotRequired[float] + lineColorFadingSaturation: NotRequired[float] + lineColorFadingHueOffset: NotRequired[float] + handlesInteractive: NotRequired[bool] + translationHandleVisibility: NotRequired[bool] + rotationHandleVisibility: NotRequired[bool] + scaleHandleVisibility: NotRequired[bool] + interactionHandleScale: NotRequired[float] + snapMode: NotRequired[str] + + +class ControlPoint(TypedDict, total=False): + id: str + label: str + description: str + associatedNodeID: str + position: list[float] # length 3 + orientation: list[float] # length 9 + selected: bool + locked: bool + visibility: bool
positionStatus: Literal["undefined", "preview", "defined"] + + +class MKR_Lines(TypedDict): + key_points: list[tuple[int, int]] + color: NotRequired[list[float]] + name: NotRequired[str] + display: NotRequired[MKR_Display] + controlPoint: NotRequired[ControlPoint] + + +class Markup(TypedDict, total=False): + type: MarkupType + name: str + coordinateSystem: CoordinateSystem + coordinateUnits: str | list[str] + locked: bool + fixedNumberOfControlPoints: bool + labelFormat: str + lastUsedControlPointNumber: int + controlPoints: list[ControlPoint] + + # Optional (ROI/Plane) + roiType: str | None + insideOut: bool | None + planeType: str | None + sizeMode: str | None + autoScalingSizeFactor: float | None + center: list[float] | None + normal: list[float] | None + size: list[float] | None + planeBounds: list[float] | None + objectToBase: list[float] | None + baseToNode: list[float] | None + orientation: list[float] | None + display: MKR_Display + measurements: Any + + +class MeasurementVolumeMarkup(Markup, total=False): + type: Literal["MeasurementVolume"] + volume: float + volumeUnit: VolumeUnit + surfaceArea: float + boundingBox: list[float] + + +MKR_DEFINITION = MKR_Lines | dict + + +def _get_display_dict( + display: MKR_Display | dict, color=None, selectedColor=None, activeColor=None, pointLabelsVisibility=False, glyphScale=1.0 +): + if activeColor is None: + activeColor = [0.4, 1.0, 0.0] + if selectedColor is None: + selectedColor = [1.0, 0.5, 0.5] + if color is None: + color = [0.4, 1.0, 1.0] + # hard cast to float, or all of "display" will be ignored... + return { + "visibility": display.get("visibility", True), + "opacity": float(display.get("opacity", 1.0)), + "color": display.get("color", color), + "selectedColor": display.get("selectedColor", selectedColor), + "activeColor": display.get("activeColor", activeColor), + # Add other properties as needed, using similar patterns: + "propertiesLabelVisibility": display.get("propertiesLabelVisibility", False), + "pointLabelsVisibility": display.get("pointLabelsVisibility", pointLabelsVisibility), + "textScale": float(display.get("textScale", 3.0)), + "glyphType": display.get("glyphType", "Sphere3D"), + "glyphScale": display.get("glyphScale", glyphScale), + "glyphSize": float(display.get("glyphSize", 5.0)), + "useGlyphScale": display.get("useGlyphScale", True), + "sliceProjection": display.get("sliceProjection", False), + "sliceProjectionUseFiducialColor": display.get("sliceProjectionUseFiducialColor", True), + "sliceProjectionOutlinedBehindSlicePlane": display.get("sliceProjectionOutlinedBehindSlicePlane", False), + "sliceProjectionColor": display.get("sliceProjectionColor", [1.0, 1.0, 1.0]), + "sliceProjectionOpacity": float(display.get("sliceProjectionOpacity", 0.6)), + "lineThickness": float(display.get("lineThickness", 0.2)), + "lineColorFadingStart": float(display.get("lineColorFadingStart", 1.0)), + "lineColorFadingEnd": float(display.get("lineColorFadingEnd", 10.0)), + "lineColorFadingSaturation": float(display.get("lineColorFadingSaturation", 1.0)), + "lineColorFadingHueOffset": float(display.get("lineColorFadingHueOffset", 0.0)), + "handlesInteractive": display.get("handlesInteractive", False), + "translationHandleVisibility": display.get("translationHandleVisibility", False), + "rotationHandleVisibility": display.get("rotationHandleVisibility", False), + "scaleHandleVisibility": display.get("scaleHandleVisibility", False), + "interactionHandleScale": float(display.get("interactionHandleScale", 3.0)), + "snapMode": 
display.get("snapMode", "toVisibleSurface"), + } + + +def _get_markup_color(definition: MKR_DEFINITION, region, subregion, split_by_region=False, split_by_subregion=False): + color = definition.get("color", None) + if color is None: + if split_by_region: + color = get_color_by_label(region) + if split_by_subregion: + color = get_color_by_label(subregion + 10) if subregion == 83 else get_color_by_label(subregion).rgb + if color is None: + color = RGB_Color.init_list([random.randint(20, 245), random.randint(20, 245), random.randint(20, 245)]) + if isinstance(color, RGB_Color): # or str(type(color)) == "RGB_Color": + color = (color.rgb / 255.0).tolist() + if isinstance(color, np.ndarray): + color = color.tolist() + assert isinstance(color, list), (color, type(color)) + if max(color) > 2: + color = [float(c) / 255.0 for c in color] + return color + + +def _get_control_point(cp: ControlPoint, position, id_name="", label="", name="", name2="") -> ControlPoint: + return { + "id": cp.get("id", id_name), + "label": cp.get("label", label), + "description": cp.get("description", name), + "associatedNodeID": cp.get("associatedNodeID", name2), + "position": cp.get("position", position), + "orientation": cp.get("orientation", [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]), + "selected": cp.get("selected", True), + "locked": cp.get("locked", False), + "visibility": cp.get("visibility", True), + "positionStatus": cp.get("positionStatus", "defined"), + } + + +def _make_default_markup( + markup_type: MarkupType, name, coordinateSystem: CoordinateSystem, controlPoints=None, display=None +) -> Markup | MeasurementVolumeMarkup: + if controlPoints is None: + controlPoints = [] + base: Markup = { + "type": markup_type, + "coordinateSystem": coordinateSystem, + "coordinateUnits": "mm", + "locked": True, + "fixedNumberOfControlPoints": False, + "labelFormat": "%N-%d", + "lastUsedControlPointNumber": 0, + "controlPoints": controlPoints, + } + if name is not None: + base["name"] = name + if display is not None: + base["display"] = display + if markup_type == "ROI": + base.update({"roiType": "Box", "insideOut": False}) + elif markup_type == "Plane": + base.update({"planeType": "PointNormal", "sizeMode": "auto", "autoScalingSizeFactor": 1.0}) + elif markup_type == "MeasurementVolume": + mv: MeasurementVolumeMarkup = { + **base, + "type": "MeasurementVolume", + "volume": 0.0, + "volumeUnit": "mm3", + "surfaceArea": 0.0, + "boundingBox": [0.0] * 6, + } + return mv + + return base + + +def _get_markup_lines( + definition: MKR_Lines, + poi: "POI_Global", + coordinate_system: Literal["LPS", "RAS"], + split_by_region=False, + split_by_subregion=False, + display=None, +): + if display is None: + display = {} + key_points = definition.get("key_points") + + region, subregion = key_points[0] + color = _get_markup_color(definition, region, subregion, split_by_region, split_by_subregion) + display = _get_display_dict(display | definition.get("display", {}), selectedColor=color) + + controlPoints = [] + for region, subregion in key_points: + name, name2 = get_desc(poi, region, subregion) + controlPoints.append( + _get_control_point( + definition.get("controlPoint", {}), poi[region, subregion], f"{region}-{subregion}", f"{region}-{subregion}", name, name2 + ) + ) + return _make_default_markup("Line", definition.get("name"), coordinate_system, controlPoints=controlPoints, display=display) + + +def get_desc(self: "POI_Global", region, subregion): + try: + name = self.level_two_info(subregion).name + except Exception: + name = 
str(subregion) + try: + name2 = self.level_one_info(region).name + except Exception: + name2 = str(region) + return name, name2 + + +def _save_mrk( + poi: "POI_Global", + filepath: str | Path, + color=None, + split_by_region=True, + split_by_subregion=False, + add_points: bool = True, + add_lines: list[MKR_Lines] | None = None, + display: MKR_Display | dict = None, # type: ignore + pointLabelsVisibility=False, + glyphScale=1.0, + **args, +): + """ + Save the POI data to a .mrk.json file in Slicer Markups format. + Automatically sets coordinate system based on itk_coords. + Includes level_one_info and level_two_info in the description. + Preserves metadata from `info` dictionary. + """ + if display is None: + display = {} + if add_lines is None: + add_lines = [] + filepath = Path(filepath) + if not filepath.name.endswith(".mrk.json"): + filepath = filepath.parent / (filepath.stem + ".mrk.json") + coordinate_system: CoordinateSystem = "LPS" if poi.itk_coords else "RAS" + list_markups = {} + # ADD POINTS + addendum = { + "pointLabelsVisibility": pointLabelsVisibility, + "glyphScale": float(glyphScale), + **args, + } + display = addendum | display + if add_points: + # Create list of control points + for region, subregion, coords in poi.centroids.items(): + key = "P" + if split_by_region: + key += str(region) + "_" + if split_by_subregion: + key += str(subregion) + name, name2 = get_desc(poi, region, subregion) + if key not in list_markups: + list_markups[key] = _make_default_markup( + "Fiducial", + key, + coordinate_system, + controlPoints=[], + display=_get_display_dict( + display, + selectedColor=_get_markup_color( + {"color": color}, region, subregion, split_by_region=split_by_region, split_by_subregion=split_by_subregion + ), + **addendum, + ), + ) + list_markups[key]["controlPoints"].append( + _get_control_point({}, coords, f"{region}-{subregion}", f"{region}-{subregion}", name, name2) + ) + markups = list(list_markups.values()) + # Lines + for line in add_lines: + markups.append(_get_markup_lines(line, poi, coordinate_system, split_by_region, split_by_subregion, display)) + mrk_data = { + "@schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#", + "markups": markups, + } + filepath.unlink(missing_ok=True) + with open(filepath, "w") as f: + json.dump(mrk_data, f, indent=2) + log.on_save(f"Saved .mrk.json to {filepath}") diff --git a/TPTBox/core/poi_fun/vertebra_direction.py b/TPTBox/core/poi_fun/vertebra_direction.py index 15541c8..774416e 100644 --- a/TPTBox/core/poi_fun/vertebra_direction.py +++ b/TPTBox/core/poi_fun/vertebra_direction.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Sequence +from warnings import warn import numpy as np from numpy.linalg import norm @@ -25,6 +26,7 @@ def calc_orientation_of_vertebra_PIR( do_fill_back: bool = False, spine_plot_path: None | str = None, save_normals_in_info=False, + _orientation_version=0, ) -> tuple[POI, NII | None]: """Calculate the orientation of vertebrae using PIR (Posterior, Inferior, Right) DIRECTIONS.
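
The new writer in save_mkr.py is the backend for exporting global POIs as Slicer markups, and it can also be driven directly. A minimal sketch, assuming `poi` is an existing `POI_Global` whose centroids contain the `(region, subregion)` pairs used below (the ids 20/21 and 50 are placeholders, not guaranteed labels); the line definition follows the `MKR_Lines` TypedDict defined in this file:

```python
# Sketch: exercising the new Slicer-markup writer directly.
# Assumptions: `poi` is a POI_Global that already holds the
# (region, subregion) keys listed in `key_points` (placeholder ids).
from TPTBox.core.poi_fun.save_mkr import MKR_Lines, _save_mrk

line: MKR_Lines = {
    "key_points": [(20, 50), (21, 50)],  # connect two existing POIs with a Line markup
    "name": "corpus_line",
    "color": [1.0, 0.0, 0.0],
}
_save_mrk(
    poi,
    "example.mrk.json",     # a missing .mrk.json suffix is appended automatically
    split_by_region=True,   # one Fiducial markup per region, keyed "P<region>_"
    add_lines=[line],
    pointLabelsVisibility=True,
    glyphScale=2.0,
)
```

The written file carries the markups-schema-v1.0.3 `@schema` URL, which is exactly what the schema check in `_load_mkr_POI` above expects.
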
@@ -45,6 +47,8 @@ def calc_orientation_of_vertebra_PIR( assert poi is None or poi.zoom is not None from TPTBox import calc_centroids + if _orientation_version != 0: + warn("outdated _orientation_version; set it to 0", stacklevel=1) # Step 1 compute the up direction # check if label 50 is already computed in POI if poi is None or spline_subreg_point_id.value not in poi.keys_subregion(): @@ -56,16 +60,22 @@ intersection_target = [Location.Spinosus_Process, Location.Arcus_Vertebrae] # We compute everything in iso space subreg_iso = subreg.rescale().reorient() - - target_labels = subreg_iso.extract_label(intersection_target).get_array() - # we want to see more of the Spinosus_Process and Arcus_Vertebrae than we cut with the plane. Should reduce randomness. - # The ideal solution would be to make a projection onto the plane. Instead we fill values that have a vertical distanc of 10 mm up and down. This approximates the projection on to the plane. - # Without this we have the chance to miss most of the arcus and spinosus, witch leads to instability in the direction. - # TODO this will fail if the vertebra is not roughly aligned with S/I-direction - for _ in range(15): - target_labels[:, :-1] += target_labels[:, 1:] - target_labels[:, 1:] += target_labels[:, :-1] - target_labels = np.clip(target_labels, 0, 1) + if _orientation_version == 0: + target_labels = subreg_iso.extract_label(intersection_target).get_array() + # we want to see more of the Spinosus_Process and Arcus_Vertebrae than we cut with the plane. Should reduce randomness. + # The ideal solution would be to make a projection onto the plane. Instead we fill values that have a vertical distance of 10 mm up and down. This approximates the projection onto the plane. + # Without this we have the chance to miss most of the arcus and spinosus, which leads to instability in the direction.
+ # TODO this will fail if the vertebra is not roughly aligned with S/I-direction + for _ in range(15): + target_labels[:, :-1] += target_labels[:, 1:] + target_labels[:, 1:] += target_labels[:, :-1] + target_labels = np.clip(target_labels, 0, 1) + elif _orientation_version == 1: + target_labels = subreg_iso.extract_label(intersection_target).get_array() + elif _orientation_version == 2: + target_labels = subreg_iso.extract_label(list(range(40, 49))).get_array() + else: + raise NotImplementedError(_orientation_version) out = target_labels * 0 fill_back_nii = subreg_iso.copy() if do_fill_back else None fill_back = out.copy() if do_fill_back else None @@ -120,12 +130,7 @@ def calc_orientation_of_vertebra_PIR( arr = subreg_sar.get_array() fill_back_nii.set_array_(arr) - ret = calc_centroids( - subreg_iso.set_array(out), - second_stage=subreg_id, - extend_to=poi_iso.copy(), - inplace=True, - ) + ret = calc_centroids(subreg_iso.set_array(out), second_stage=subreg_id, extend_to=poi_iso.copy(), inplace=True) poi._vert_orientation_pir = {} if save_normals_in_info: diff --git a/TPTBox/core/poi_fun/vertebra_pois_non_centroids.py b/TPTBox/core/poi_fun/vertebra_pois_non_centroids.py index 1aa7625..3c16058 100755 --- a/TPTBox/core/poi_fun/vertebra_pois_non_centroids.py +++ b/TPTBox/core/poi_fun/vertebra_pois_non_centroids.py @@ -244,6 +244,7 @@ def compute_non_centroid_pois( # noqa: C901 subreg: NII, _vert_ids: Sequence[int] | None = None, log: Logger_Interface = _log, + _orientation_version=0, ): if _vert_ids is None: _vert_ids = vert.unique() @@ -257,7 +258,9 @@ def compute_non_centroid_pois( # noqa: C901 ### Calc vertebra direction; We always need them, so we just compute them. ### sub_regions = poi.keys_subregion() if any(a.value not in sub_regions for a in vert_directions): - poi, _ = calc_orientation_of_vertebra_PIR(poi, vert, subreg, do_fill_back=False, save_normals_in_info=False) + poi, _ = calc_orientation_of_vertebra_PIR( + poi, vert, subreg, do_fill_back=False, save_normals_in_info=False, _orientation_version=_orientation_version + ) [locations.remove(i) for i in vert_directions if i in locations] locations = [pois_computed_by_side_effect.get(l.value, l) for l in locations] diff --git a/TPTBox/core/vert_constants.py b/TPTBox/core/vert_constants.py index f3be05d..79d4e94 100755 --- a/TPTBox/core/vert_constants.py +++ b/TPTBox/core/vert_constants.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing from collections.abc import Sequence from enum import Enum from typing import TYPE_CHECKING, Literal, NoReturn, Union @@ -230,6 +229,7 @@ class Full_Body_Instance(Abstract_lvl): subcutaneous_fat = 57 muscle_other = 58 inner_fat = 59 + ignore = 60 class Lower_Body(Abstract_lvl): @@ -357,7 +357,21 @@ def cervical(cls): @classmethod def thoracic(cls): - return (cls.T1, cls.T2, cls.T3, cls.T4, cls.T5, cls.T6, cls.T7, cls.T8, cls.T9, cls.T10, cls.T11, cls.T12, cls.T13) + return ( + cls.T1, + cls.T2, + cls.T3, + cls.T4, + cls.T5, + cls.T6, + cls.T7, + cls.T8, + cls.T9, + cls.T10, + cls.T11, + cls.T12, + cls.T13, + ) @classmethod def lumbar(cls): diff --git a/TPTBox/images/poi_preview.png b/TPTBox/images/poi_preview.png new file mode 100644 index 0000000..2a8a457 Binary files /dev/null and b/TPTBox/images/poi_preview.png differ diff --git a/TPTBox/images/snp2D_example.png b/TPTBox/images/snp2D_example.png new file mode 100644 index 0000000..fd7a096 Binary files /dev/null and b/TPTBox/images/snp2D_example.png differ diff --git a/TPTBox/images/snp3D_example.jpg 
b/TPTBox/images/snp3D_example.jpg new file mode 100644 index 0000000..c2db755 Binary files /dev/null and b/TPTBox/images/snp3D_example.jpg differ diff --git a/TPTBox/mesh3D/html_preview.py b/TPTBox/mesh3D/html_preview.py new file mode 100644 index 0000000..a075ce5 --- /dev/null +++ b/TPTBox/mesh3D/html_preview.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import pyvista as pv + +# from utils.filepaths import filepath_data, filepath_dataset +# from utils.poi_plotter import * +# from utils.poi_surface import project_pois_onto_segmentation_surface +from TPTBox import POI, Has_Grid, Log_Type, Print_Logger +from TPTBox.core.nii_wrapper import NII, to_nii +from TPTBox.mesh3D.mesh import Mesh3D, POIMesh, SegmentationMesh +from TPTBox.mesh3D.mesh_colors import RGB_Color, get_color_by_label + + +def _add_mesh(pl, mesh: pv.PolyData | SegmentationMesh, color: str | RGB_Color, opacity: float = 1.0): + if isinstance(mesh, Mesh3D): + mesh = mesh.mesh + if isinstance(color, RGB_Color): + color = color.rgb.tolist() + + pl.add_mesh(mesh, opacity=opacity, color=color) + + +@dataclass +class Preview_Settings: + """ + Configuration for visualizing a `NII` or `POI` object as a mesh. + + Attributes: + obj (NII | POI): The image or point-of-interest object to visualize. + offset (tuple[float, float, float] | None): Optional (PIR) spatial offset for rendering. + opacity (float): Mesh opacity value between 0 and 1. + color (Literal["auto"] | str | None): Desired mesh color. Defaults to "auto". + binary (bool): Whether to render the object as a binary segmentation. + """ + + obj: NII | POI + offset: tuple[float, float, float] | None = None # PIR + opacity: float = 1.0 + color: Literal["auto"] | str | None = "auto" # noqa: PYI051 + binary = False + + def _get_mesh( + self, + rescale_to_iso, + poi_size, + default_color_nii="bisque", + default_poi_nii="red", + ): + """ + Generates one or more meshes from the underlying image or POI. + + Args: + rescale_to_iso (bool): Whether to rescale to isotropic spacing. + poi_size (float): Size factor for rendering POI objects. + default_color_nii (str): Default color for NII objects. + default_poi_nii (str): Default color for POI objects. + + Yields: + Tuple[Mesh, str]: A tuple of the mesh and its associated color. 
+ """ + img = self.obj + if (self.color is None or self.color == "auto") and not isinstance(img, NII): + self.color = default_poi_nii + if self.binary or (self.color is not None and self.color != "auto"): + if isinstance(img, NII): + mesh = SegmentationMesh.from_segmentation_nii(img, rescale_to_iso=rescale_to_iso) + is_poi = False + elif isinstance(img, POI): + mesh = POIMesh(img, rescale_to_iso=False, regions=None, subregions=None, size_factor=poi_size) + is_poi = True + else: + raise NotImplementedError(f"{img.__class__} is not supported") + if self.offset is not None: + mesh = mesh.get_mesh_with_offset(self.offset) + color = self.color + if color is None or color == "auto": + color = default_poi_nii if is_poi else default_color_nii + yield mesh, color + elif isinstance(img, NII): + for u in img.unique(): + color = get_color_by_label(u) + mesh = SegmentationMesh.from_segmentation_nii(img.extract_label(u), rescale_to_iso=rescale_to_iso) + if self.offset is not None: + mesh = mesh.get_mesh_with_offset(self.offset) + yield mesh, color + else: + raise NotImplementedError("auto poi color") + + +l = Print_Logger() + +offset = tuple[float, float, float] + + +def make_html_preview( + images: list[NII | POI | Preview_Settings], + html_out: str | Path | None, + background="black", + rescale_to_iso=False, + poi_size=1.7, + logger=l, + show=False, + default_color_nii="bisque", + default_poi_nii="red", + ref_spacing: Has_Grid | None = None, + auto_rescale_to_ref=False, +): + """ + Render NII or POI objects as meshes in an interactive 3D HTML viewer. + + Args: + images (list[NII | POI | Preview_Settings]): List of images or wrapped settings to visualize. + html_out (str | Path | None): Output file path for HTML export. Must end in `.html`. + background (str): Background color of the 3D viewer. + rescale_to_iso (bool): Whether to rescale NII images to isotropic voxel spacing. + poi_size (float): Size factor for point-of-interest rendering. + logger (Print_Logger): Logger for output messages. + show (bool): If True, shows the viewer after rendering. + default_color_nii (str): Default color for NII-based visualizations. + default_poi_nii (str): Default color for POI-based visualizations. + ref_spacing (Has_Grid | None): Optional reference object to resample all images to a common spacing. + auto_rescale_to_ref (bool): Whether to resample all objects to the reference spacing automatically. + + Raises: + AssertionError: If `html_out` is invalid or neither `html_out` nor `show` is provided. 
+ """ + assert (html_out is None) or str(html_out).endswith(".html"), f"not a valid file ending {html_out}; expected .html" + assert html_out is not None or show, "show must be True or html_out must be set" + pl: pv.Plotter = pv.Plotter() # type: ignore + pl.set_background(background, top=None) # type: ignore + pl.add_axes() # type: ignore + + images_ = [Preview_Settings(obj) if not isinstance(obj, Preview_Settings) else obj for obj in images] + if ref_spacing is not None: + auto_rescale_to_ref = True + if auto_rescale_to_ref and ref_spacing is None: + for a in images_: + if isinstance(a, Has_Grid): + ref_spacing = a + break + if auto_rescale_to_ref: + assert ref_spacing is not None + + def resample(obj: Preview_Settings) -> Preview_Settings: + obj.obj = obj.obj.resample_from_to(ref_spacing) + return obj + + images_ = [resample(obj) for obj in images_] + + for setting in images_: + for m, color in setting._get_mesh( + poi_size=poi_size, rescale_to_iso=rescale_to_iso, default_color_nii=default_color_nii, default_poi_nii=default_poi_nii + ): + _add_mesh(pl, m, opacity=setting.opacity, color=color) + + if html_out is not None: + pl.export_html(html_out) + logger.print(f"Saved scene into {html_out}", Log_Type.SAVE) + if show: + pl.show() + + +if __name__ == "__main__": + p = "/media/data/robert/dataset-myelom/dataset-myelom/derivatives-leg/MM00191/ses-20180502" + nii = to_nii(Path(p, "sub-MM00191_ses-20180502_sequ-202_seg-leg-left_msk.nii.gz"), True) + poi = POI.load(Path(p, "sub-MM00191_ses-20180502_sequ-202_seg-leg-subreg-left_poi.mrk.json"), reference=nii) + nii_r = to_nii(Path(p, "sub-MM00191_ses-20180502_sequ-202_seg-leg-right_msk.nii.gz"), True) + poi_r = POI.load(Path(p, "sub-MM00191_ses-20180502_sequ-202_seg-leg-subreg-right_poi.mrk.json"), reference=nii) + make_html_preview( + [ + nii, + poi, + nii_r, + poi_r, + Preview_Settings(nii, offset=(0, 0, -250), opacity=0.5), + Preview_Settings(poi, offset=(0, 0, -250)), + ], + Path(p, "test.html"), + show=True, + poi_size=10, + ) diff --git a/TPTBox/mesh3D/snapshot3D.py b/TPTBox/mesh3D/snapshot3D.py index 18ad63d..5fe5af3 100644 --- a/TPTBox/mesh3D/snapshot3D.py +++ b/TPTBox/mesh3D/snapshot3D.py @@ -103,7 +103,7 @@ def make_snapshot3D( window_size = (width * len(ids_list), nii.shape[1]) with Xvfb(): scene = window.Scene() - show_m = window.ShowManager(scene, size=window_size, reset_camera=False) + show_m = window.ShowManager(scene=scene, size=window_size, reset_camera=False) show_m.initialize() for i, ids in enumerate(ids_list): x = width * i @@ -120,25 +120,28 @@ def make_snapshot3D( return out_img -def make_sub_snapshot_parallel( - imgs: list[Path], +def make_snapshot3D_parallel( + imgs: list[Path | str], output_paths: list[Image_Reference], - orientation: VIEW | list[VIEW] = "A", + view: VIEW | list[VIEW] = "A", ids_list: list[Sequence[int]] | None = None, smoothing=20, - resolution=2, + resolution: float = 2, cpus=10, width_factor=1.0, + override=True, ): ress = [] with Pool(cpus) as p: # type: ignore for out_path, img in zip_strict(output_paths, imgs): + if not override and Path(out_path).exists(): + continue res = p.apply_async( make_snapshot3D, kwds={ "output_path": out_path, "img": img, - "view": orientation, + "view": view, "ids_list": ids_list, "smoothing": smoothing, "resolution": resolution, @@ -152,6 +155,9 @@ def make_sub_snapshot_parallel( p.join() +make_sub_snapshot_parallel = make_snapshot3D_parallel + + def _plot_sub_seg(scene: window.Scene, nii: NII, x, y, smoothing, orientation: VIEW): if orientation == "A": # [ axis1(w) 
] [ axis2(h) ] [ view in ] @@ -174,11 +180,27 @@ def _plot_sub_seg(scene: window.Scene, nii: NII, x, y, smoothing, orientation: V raise NotImplementedError() for idx in nii.unique(): color = get_color_by_label(idx) - cont_actor = _plot_mask(nii.extract_label(idx), affine, x, y, smoothing=smoothing, color=color, opacity=1) + cont_actor = _plot_mask( + nii.extract_label(idx), + affine, + x, + y, + smoothing=smoothing, + color=color, + opacity=1, + ) scene.add(cont_actor) -def _plot_mask(nii: NII, affine, x_current, y_current, smoothing=10, color: list | np.ndarray = _red, opacity=1): +def _plot_mask( + nii: NII, + affine, + x_current, + y_current, + smoothing=10, + color: list | np.ndarray = _red, + opacity=1, +): mask = nii.get_seg_array() cont_actor = _contour_from_roi_smooth(mask, affine=affine, color=color, opacity=opacity, smoothing=smoothing) cont_actor.SetPosition(x_current, y_current, 0) diff --git a/TPTBox/registration/deepali/deepali_model.py b/TPTBox/registration/deepali/deepali_model.py index 9f624d0..581ac62 100644 --- a/TPTBox/registration/deepali/deepali_model.py +++ b/TPTBox/registration/deepali/deepali_model.py @@ -6,7 +6,7 @@ import time from collections.abc import Sequence from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Optional, Self, Union import torch import torch.optim @@ -15,7 +15,9 @@ from deepali.core import Grid as Deepali_Grid from deepali.data import Image as deepaliImage from deepali.modules import TransformImage -from deepali.spatial import SpatialTransform +from deepali.spatial import ( + SpatialTransform, +) from TPTBox import NII, POI, Image_Reference, to_nii from TPTBox.core.compat import zip_strict @@ -55,18 +57,33 @@ def _load_config(path): def _warp_image( - source_image: deepaliImage, target_grid: Deepali_Grid, transform: SpatialTransform, mode="linear", device=default_device, inverse=False + source_image: deepaliImage, + target_grid: Deepali_Grid, + transform: SpatialTransform, + mode="linear", + device=default_device, + inverse=False, ) -> torch.Tensor: if inverse: transform = transform.inverse(update_buffers=True) - warp_func = TransformImage(target=target_grid, source=target_grid, sampling=mode, padding=source_image.min()).to(device) + warp_func = TransformImage( + target=target_grid, + source=target_grid, + sampling=mode, + padding=source_image.min(), + ).to(device) with torch.inference_mode(): data = warp_func(transform.tensor(), source_image.to(device)) return data def _warp_poi( - poi_moving: POI, target_grid: TPTBox_Grid, transform: SpatialTransform, align_corners, device=default_device, inverse=True + poi_moving: POI, + target_grid: TPTBox_Grid, + transform: SpatialTransform, + align_corners, + device=default_device, + inverse=True, ) -> POI: keys: list[tuple[int, int]] = [] points = [] @@ -75,6 +92,7 @@ def _warp_poi( points.append((x, y, z)) print(key, key2, (x, y, z)) with torch.inference_mode(): + assert len(points) != 0 # Ensure points is not empty data = torch.Tensor(points) transform.to(device) # data2 = data @@ -94,6 +112,36 @@ def _warp_poi( return out_poi +def _warp_points( + points, + axes: Axes, + to_axes: Axes, + grid: Deepali_Grid, + to_grid: Deepali_Grid, + transform: SpatialTransform, + device=default_device, + inverse=True, +) -> torch.Tensor: + """ + Warp points using a spatial transform. + Args: + points (list): List of points to warp: (b,n) b points with n coordinates. + transform (SpatialTransform): Spatial transform to apply. 
+ align_corners (bool): Whether to align corners during warping. + device (torch.device, optional): Device to perform computation on. Defaults to default_device. + inverse (bool, optional): Whether to apply the inverse transform. Defaults to True. + """ + with torch.inference_mode(): + data = torch.Tensor(points) + transform.to(device) + # data2 = data + if inverse: + transform = transform.inverse(update_buffers=True) + data = transform.points(data.to(device), axes=axes, to_axes=to_axes, grid=grid, to_grid=to_grid) + + return data.cpu() + + class General_Registration(DeepaliPairwiseImageTrainer): """ A class for performing deformable registration between a fixed and moving image. @@ -109,6 +157,8 @@ def __init__( self, fixed_image: Image_Reference, moving_image: Image_Reference, + fixed_seg: Image_Reference | None = None, + moving_seg: Image_Reference | None = None, reference_image: Image_Reference | None = None, source_pset=None, target_pset=None, @@ -123,14 +173,13 @@ def __init__( fixed_mask: Image_Reference | None = None, moving_mask: Image_Reference | None = None, # normalize - normalize_strategy: Optional[ - Literal["auto", "CT", "MRI"] - ] = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: Literal["auto", "CT", "MRI"] + | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting # Pyramid - pyramid_levels: Optional[int] = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -138,16 +187,16 @@ def __init__( transform_args: dict | None = None, transform_init: PathStr | None = None, # reload initial flowfield from file optim_name="Adam", # Optimizer name defined in torch.optim. or override on_optimizer finer control - lr: float | list[float] = 0.01, # Learning rate + lr: float | Sequence[float] = 0.01, # Learning rate optim_args=None, # args of Optimizer with out lr smooth_grad=0.0, verbose=99, max_steps: int | Sequence[int] = 250, # Early stopping. override on_converged finer control max_history: int | None = None, min_value=0.0, # Early stopping. override on_converged finer control - min_delta=0.0, # Early stopping. override on_converged finer control + min_delta: float | Sequence[float] = 0.0, # Early stopping. 
override on_converged finer control loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, - weights: list[float] | dict[str, float] | None = None, + weights: list[float] | dict[str, float | list[float]] | None = None, auto_run=True, ) -> None: if device is None: @@ -160,6 +209,8 @@ def __init__( reference_image = fix else: fix = fix.resample_from_to(reference_image) + if fixed_seg is not None: + fixed_seg = to_nii(fixed_seg, True).resample_from_to(reference_image) ## Resample and save images source = mov # .resample_from_to_(reference_image) ## Load configuration and perform registration @@ -167,9 +218,13 @@ def __init__( self.input_grid = mov.to_gird() self.source_landmarks_poi = source_landmarks self.target_landmarks_poi = target_landmarks + self._is_inverted = False + super().__init__( source=source.to_deepali(), target=fix.to_deepali(), + source_seg=to_nii(moving_seg, True).to_deepali() if moving_seg is not None else None, + target_seg=to_nii(fixed_seg, True).to_deepali() if fixed_seg is not None else None, source_pset=source_pset, target_pset=target_pset, source_landmarks=source_landmarks, @@ -234,6 +289,90 @@ def __init__( # # .unsqueeze(0) # # ) # return data.clone() + def inverse(self) -> Self: + """ + Invert the registration transformation. + + Returns: + Self: A copy of the instance with the inverted transformation. + """ + from copy import copy + + out = copy(self) + out._is_inverted = not self._is_inverted + return out + + # def on_run_end( + # self, + # grid_transform, + # target_image: deepaliImage, + # source_image: deepaliImage, + # target_image_seg: deepaliImage, + # source_image_seg: deepaliImage, + # opt, + # lr_sq, + # num_steps, + # level, + # ): + # import numpy as np + # + # arr_target = ( + # target_image.tensor() + # .squeeze() + # .permute(2, 1, 0) + # .detach() + # .cpu() + # .float() + # .numpy() + # ) + # grid = NII.from_deepali_grid(target_image.grid()) + # nii_target = grid.make_nii(arr_target, False) + # nii_target.save( + # f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/target_img{level}.nii.gz" + # ) + # arr_source = ( + # source_image.tensor() + # .squeeze() + # .permute(2, 1, 0) + # .detach() + # .cpu() + # .float() + # .numpy() + # ) + # nii_source = grid.make_nii(arr_source, False) + # nii_source.save( + # f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/source_img{level}.nii.gz" + # ) + # arr = source_image_seg.tensor().permute(0, 3, 2, 1).detach().cpu().numpy() + # + # arr_new_source_seg = np.zeros(arr.shape[-3:]) + # print(arr_new_source_seg.shape) + # print(arr.shape) + # for i in range(arr.shape[0]): + # arr_new_source_seg[arr[i] >= 0.5] = i + # nii_source = grid.make_nii(arr_new_source_seg.astype(np.uint16), True) + # nii_source.save( + # f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/source{level}.nii.gz" + # ) + # arr_target_seg = target_image_seg.tensor().permute(0, 3, 2, 1).detach().cpu().numpy() + # + # arr_new_target_seg = np.zeros(arr_target_seg.shape[-3:]) + # for i in range(arr_target_seg.shape[0]): + # arr_new_target_seg[arr_target_seg[i] >= 0.5] = i + # nii_target_seg = grid.make_nii(arr_new_target_seg.astype(np.uint16), True) + # nii_target_seg.save( + # f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/target{level}.nii.gz" + # ) + # out = self.transform_nii(nii_target_seg) + # out.save( + # f"/DATA/NAS/datasets_processed/CT_spine/dataset-myelom/moved{level}.nii.gz" + # ) + # dice = 
out.resample_from_to(nii_source).dice(nii_source) + from TPTBox import Print_Logger + # + # Print_Logger().on_debug(np.mean(list(dice.values())), dice) + # # exit() @torch.no_grad() def transform_nii( @@ -254,23 +393,88 @@ Returns: NII: The transformed image as an NII object. """ + if self._is_inverted: + inverse = not inverse device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device target_grid_nii = self.target_grid if target is None else target target_grid = target_grid_nii.to_deepali_grid(align_corners) - source_image = img.resample_from_to(self.input_grid).to_deepali() + source_image = img.resample_from_to(self.input_grid, mode="constant").to_deepali() data = _warp_image( - source_image, target_grid, self.transform, "nearest" if img.seg else "linear", device=device, inverse=inverse + source_image, + target_grid, + self.transform, + "nearest" if img.seg else "linear", + device=device, + inverse=inverse, ).squeeze() data: torch.Tensor = data.permute(*torch.arange(data.ndim - 1, -1, -1)) # type: ignore out = target_grid_nii.make_nii(data.detach().cpu().numpy(), img.seg) return out - def transform_poi(self, poi: POI, gpu: int | None = None, ddevice: DEVICES | None = None, align_corners=True, inverse=True): + def transform_poi( + self, + poi: POI, + gpu: int | None = None, + ddevice: DEVICES | None = None, + align_corners=True, + inverse=True, + ): + if self._is_inverted: + inverse = not inverse device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device source_image = poi.resample_from_to(self.target_grid) - data = _warp_poi(source_image, self.target_grid, self.transform, align_corners, device=device, inverse=inverse) + data = _warp_poi( + source_image, + self.target_grid, + self.transform, + align_corners, + device=device, + inverse=inverse, + ) return data.resample_from_to(self.target_grid) + def transform_points( + self, + points, + axes: Axes, + to_axes: Axes, + grid: Deepali_Grid | Has_Grid, + to_grid: Deepali_Grid | Has_Grid, + gpu: int | None = None, + ddevice: DEVICES | None = None, + inverse=True, + ): + """ + Transform a set of points using the registered transformation. + Args: + points (list): List of points to warp: (b,n) b points with n coordinates. + axes (Axes): Axes of the input points. + to_axes (Axes): Axes of the output points. + grid (Deepali_Grid | Has_Grid): The grid to which the points belong. + to_grid (Deepali_Grid | Has_Grid): The target grid for the transformed points. + gpu (int, optional): GPU index to use. Defaults to None. + ddevice (DEVICES, optional): Device type. Defaults to "cuda". + inverse (bool, optional): Whether to apply the inverse transformation. Defaults to True. + """ + + if self._is_inverted: + inverse = not inverse + if isinstance(grid, Has_Grid): + grid = grid.to_deepali_grid() + if isinstance(to_grid, Has_Grid): + to_grid = to_grid.to_deepali_grid() + device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device + return _warp_points( + points, + axes, + to_axes, + grid, + to_grid, + transform=self.transform, + device=device, + inverse=inverse, + ) + def __call__(self, *args, **kwds) -> NII: """ Call method to apply the transformation using the transform_nii method.
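
Since the direction flag `_is_inverted` now travels with the transform, a fitted registration can be applied in both directions and round-tripped through its pickle dump. A minimal usage sketch, assuming the NIfTI file names below are placeholders and the default loss configuration is acceptable:

```python
# Sketch only: "fixed.nii.gz" / "moving.nii.gz" are placeholder inputs.
# General_Registration fits during construction (auto_run=True).
from TPTBox import NII
from TPTBox.registration.deepali.deepali_model import General_Registration

fixed = NII.load("fixed.nii.gz", seg=False)
moving = NII.load("moving.nii.gz", seg=False)
reg = General_Registration(fixed, moving)

moved = reg.transform_nii(moving)          # moving -> fixed space
back = reg.inverse().transform_nii(moved)  # inverse() returns a copy with the direction flipped

reg.save("reg.pkl")                          # pickles (transform, target_grid, input_grid, _is_inverted)
reg2 = General_Registration.load("reg.pkl")  # usable for transform_nii without refitting
```
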
@@ -285,7 +489,7 @@ def __call__(self, *args, **kwds) -> NII: return self.transform_nii(*args, **kwds) def get_dump(self): - return (self.transform, self.target_grid, self.input_grid) + return (self.transform, self.target_grid, self.input_grid, self._is_inverted) def save(self, path: str | Path): with open(path, "wb") as w: @@ -297,11 +501,12 @@ def load(cls, path, gpu=0, ddevice: DEVICES = "cuda"): return cls.load_(pickle.load(w), gpu, ddevice) @classmethod - def load_(cls, w, gpu=0, ddevice: DEVICES = "cuda"): - transform, grid, mov = w + def load_(cls, w, gpu=0, ddevice: DEVICES = "cuda") -> Self: + transform, grid, mov, _is_inverted = w self = cls.__new__(cls) self.transform = transform self.target_grid = grid self.input_grid = mov + self._is_inverted = _is_inverted self.device = get_device(ddevice, gpu) return self diff --git a/TPTBox/registration/deepali/deepali_trainer.py b/TPTBox/registration/deepali/deepali_trainer.py index 314660b..a2edd5a 100644 --- a/TPTBox/registration/deepali/deepali_trainer.py +++ b/TPTBox/registration/deepali/deepali_trainer.py @@ -7,7 +7,7 @@ from copy import copy from pathlib import Path from timeit import default_timer as timer -from typing import Literal, Optional, Union +from typing import Literal, Union import torch import torch.optim @@ -69,6 +69,8 @@ def __init__( self, source: Union[Image, PathStr], target: Union[Image, PathStr], + source_seg: Union[Image, PathStr] | None = None, + target_seg: Union[Image, PathStr] | None = None, source_pset=None, target_pset=None, source_landmarks=None, @@ -80,14 +82,13 @@ def __init__( source_mask: Union[Image, PathStr] | None = None, target_mask: Union[Image, PathStr] | None = None, # normalize - normalize_strategy: Optional[ - Literal["auto", "CT", "MRI"] - ] = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: Literal["auto", "CT", "MRI"] + | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting # Pyramid - pyramid_levels: Optional[int] = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -95,16 +96,16 @@ def __init__( transform_args: dict | None = None, transform_init: PathStr | None = None, # reload initial flowfield from file optim_name="Adam", # Optimizer name defined in torch.optim. or override on_optimizer finer control - lr: float | list[float] = 0.01, # Learning rate + lr: float | Sequence[float] = 0.01, # Learning rate optim_args=None, # args of Optimizer with out lr smooth_grad=0.0, verbose=0, max_steps: int | Sequence[int] = 250, # Early stopping. override on_converged finer control max_history: int | None = None, min_value=0.0, # Early stopping. override on_converged finer control - min_delta=0.0, # Early stopping. override on_converged finer control + min_delta: float | Sequence[float] = 0.0, # Early stopping. 
override on_converged finer control loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, - weights: list[float] | dict[str, float] | dict[str, list[float]] | None = None, + weights: list[float] | dict[str, float | list[float]] | dict[str, list[float]] | None = None, ) -> None: """ Initializes the DeepaliPairwiseImageTrainer for pairwise image registration. @@ -180,39 +181,101 @@ def __init__( if transform_args is None: transform_args = {} - self._dtype = torch.float32 + self._dtype = torch.float16 self.device = get_device_config(device) - self.normalize_strategy: Optional[Literal["auto", "CT", "MRI"]] = normalize_strategy + self.normalize_strategy: Literal["auto", "CT", "MRI"] | None = normalize_strategy self.align = align self.transform_name = transform_name self.model_args = transform_args self.model_init = transform_init self.optim_name = optim_name - self.lr = lr if not isinstance(lr, (list, tuple)) else lr[::-1] + self.lr = lr if not isinstance(lr, (Sequence)) else lr[::-1] self.optim_args = optim_args - self.max_steps = max_steps + self.max_steps = max_steps if not isinstance(max_steps, (Sequence)) else max_steps[::-1] self.max_history = max_history self.min_value = min_value - self.min_delta = min_delta + self.min_delta = min_delta if not isinstance(min_delta, (Sequence)) else min_delta[::-1] self.verbose = verbose self.loss_terms, self.weights = parse_loss(loss_terms, weights) # reading images self.source = self._read(source) self.target = self._read(target) - # self.source_seg = self._read(source_seg) - # self.target_seg = self._read(target_seg) + # generate mask - self.source_mask = self._read(source_mask).type(torch.int8) if source_mask is not None else None - self.target_mask = self._read(target_mask).type(torch.int8) if target_mask is not None else None + self.source_mask = self._read(source_mask, torch.int8) if source_mask is not None else None + self.target_mask = self._read(target_mask, torch.int8) if target_mask is not None else None # normalize self.source, self.target = self.on_normalize(self.source, self.target) - # self.source_seg, self.target_seg = self.on_normalize_seg(self.source_seg, self.target_seg) # Pyramid self.source_pyramid, self.target_pyramid = self.make_pyramid( self.source, self.target, pyramid_levels, finest_level, coarsest_level, dims, pyramid_finest_spacing, pyramid_min_size ) + if source_seg is not None or target_seg is not None: + with torch.no_grad(): + self.source_seg_org = self._read(source_seg, torch.long, "cpu") + self.target_seg_org = self._read(target_seg, torch.long, "cpu") + # Get unique labels from both source and target + u = torch.unique(self.target_seg_org.tensor()) + u = u.detach().cpu() + u = [a for a in u if a != 0] + # Build a mapping from original label -> index (starting from 1) + mapping = {int(label.item()): idx for idx, label in enumerate(u, 1)} + + # Remap the segmentation labels according to mapping + source_remapped = self.source_seg_org.tensor().clone() + target_remapped = self.target_seg_org.tensor().clone() + for orig_label, new_label in mapping.items(): + source_remapped[self.source_seg_org.tensor() == orig_label] = new_label + target_remapped[self.target_seg_org.tensor() == orig_label] = new_label + + # Convert to one-hot if needed (optional) + num_classes = len(mapping) + 1 # Add 1 for background or assume no 0 + print(f"Found {num_classes=}, {source_remapped.unique()}, {target_remapped.unique()}") + one_hot_source = ( + 
(torch.nn.functional.one_hot(source_remapped.long(), num_classes).to(self._dtype).to(self.device)) + .permute(0, 4, 1, 2, 3) + .squeeze(0) + ) + one_hot_target = ( + (torch.nn.functional.one_hot(target_remapped.long(), num_classes).to(self._dtype).to(self.device)) + .permute(0, 4, 1, 2, 3) + .squeeze(0) + ) + print(f"{one_hot_target.shape=}", one_hot_target.device) + + # Wrap in Image object again + self.source_seg = Image( + one_hot_source.detach(), + self.source_seg_org.grid(), + dtype=self._dtype, + device=self.device, + ) + self.target_seg = Image( + one_hot_target.detach(), + self.target_seg_org.grid(), + dtype=self._dtype, + device=self.device, + ) + print("make_pyramid seg", self.source_seg.dtype, self.source_seg.device) + self.source_pyramid_seg, self.target_pyramid_seg = self.make_pyramid( + self.source_seg, + self.target_seg, + pyramid_levels, + finest_level, + coarsest_level, + dims, + pyramid_finest_spacing, + pyramid_min_size, + ) + print("make_pyramid seg end", self.source_seg.dtype) + else: + self.source_seg = None + self.target_seg = None + self.source_pyramid_seg = None + self.target_pyramid_seg = None + self.source_pset = source_pset self.target_pset = target_pset self.source_landmarks = source_landmarks @@ -227,13 +290,17 @@ def on_normalize(self, source: Image, target: Image): # def on_normalize_seg(self, source_seg: Optional[Image], target_seg: Optional[Image]): # return clamp_mask(source_seg), clamp_mask(target_seg) - def _read(self, source) -> Image: + def _read(self, source, dtype=None, device=None) -> Image: + if dtype is None: + dtype = self._dtype + if device is None: + device = self.device if isinstance(source, (str, Path)): - return Image.read(source, dtype=self._dtype, device=self.device) + return Image.read(source, dtype=dtype, device=device) elif hasattr(source, "to_deepali"): - source = source.to_deepali(dtype=self._dtype, device=self.device) + source = source.to_deepali(dtype=dtype, device=device) else: - source = source.to(dtype=self._dtype, device=self.device) + source = source.to(dtype=dtype, device=device) return source def _pyramid(self, target_image: Image): @@ -250,11 +317,11 @@ def make_pyramid( self, source_image: Image, target_image: Image, - levels: Optional[int] = None, + levels: int | None = None, finest_level: int = 0, - coarsest_level: Optional[int] = None, + coarsest_level: int | None = None, pyramid_dims=("x", "y", "z"), - finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + finest_spacing: Sequence[int] | torch.Tensor | None = None, min_size=16, ): if levels is None or levels <= 1: @@ -286,7 +353,7 @@ def make_pyramid( def on_make_transform(self, transform_name, grid, groups=1, **model_args): return new_spatial_transform(transform_name, grid, groups=groups, **model_args) - def on_optimizer(self, grid_transform: SequentialTransform, level) -> tuple[Optimizer, Optional[LRScheduler]]: + def on_optimizer(self, grid_transform: SequentialTransform, level) -> tuple[Optimizer, LRScheduler | None]: name = self.optim_name cls = getattr(torch.optim, name, None) if cls is None: @@ -297,15 +364,21 @@ def on_optimizer(self, grid_transform: SequentialTransform, level) -> tuple[Opti kwargs["lr"] = self.lr[level] if isinstance(self.lr, (list, tuple)) else self.lr return cls(grid_transform.parameters(), **kwargs), None - def on_converged(self) -> bool: + def on_converged(self, level) -> bool: r"""Check convergence criteria.""" - if self.min_delta == 0 and self.min_value == 0: + if isinstance(self.min_delta, (float, int)): + min_delta = 
self.min_delta + elif len(self.min_delta) > level: + min_delta = self.min_delta[level] + else: + min_delta = self.min_delta[-1] + if min_delta == 0 and self.min_value == 0: return False values = self.loss_values if not values: return False value = values[-1] - epsilon = abs(self.min_delta * value) if self.min_delta < 0 else self.min_delta + epsilon = abs(min_delta * value) if min_delta < 0 else min_delta slope = slope_of_least_squares_fit(values) if abs(slope) < epsilon: return True @@ -313,7 +386,11 @@ def _loss_terms_of_type(self, loss_type: type) -> dict[str, Module]: r"""Get dictionary of loss terms of a specific type.""" - return {name: module for name, module in self.loss_terms.items() if isinstance(module, loss_type)} # type: ignore + return { + name: module + for name, module in self.loss_terms.items() + if isinstance(module, loss_type) and not (name in ["Dice", "DCE"] or module.__class__.__name__ in ["Dice", "DCE"]) + } # type: ignore def _transforms_of_type(self, transform_type: type[SpatialTransform]) -> list[SpatialTransform]: r"""Get list of spatial transformations of a specific type.""" @@ -344,7 +421,15 @@ def _weighted_sum(self, losses: dict[str, Tensor], level) -> Tensor: loss += value.sum() return loss - def on_loss(self, grid_transform: SequentialTransform, target: Image, source: Image, level: int): # noqa: C901 + def on_loss( # noqa: C901 + self, + grid_transform: SequentialTransform, + target: Image, + source: Image, + target_image_seg: Image | None, + source_image_seg: Image | None, + level: int, + ): # noqa: C901 r"""Evaluate pairwise image registration loss.""" target_data = target.tensor() result = {} @@ -367,6 +452,23 @@ def on_loss(self, grid_transform: SequentialTransform, target: Image, source: Im result["source"] = moved_data result["target"] = target_data result["mask"] = mask + if self.loss_pairwise_image_terms2: + assert source_image_seg is not None and target_image_seg is not None, "Source and target image segmentations are required" + moved_data: torch.Tensor = self._sample_image(y, source_image_seg.tensor()) + target_data_seg = target_image_seg.tensor() + if self.source_mask is not None and self.target_mask is not None: + # TODO this is from the reference implementation, but it needs way too much GPU memory...
+ moved_mask = self._sample_image(y, self.source_mask) + mask = overlap_mask(moved_mask, self.target_mask) + else: + mask = None + for name, term in self.loss_pairwise_image_terms2.items(): + losses[name] = term( # DICE + moved_data.unsqueeze(0), target_data_seg.unsqueeze(0), mask=mask + ) + result["source"] = moved_data + result["target"] = target_data + result["mask"] = mask ## Sum of pairwise point set distance terms if self.loss_dist_terms: if self.source_pset is None: @@ -443,6 +545,8 @@ def on_step( grid_transform: SequentialTransform, target_image: Image, source_image: Image, + target_image_seg: Image | None, + source_image_seg: Image | None, opt, scheduler, num_steps, @@ -455,7 +559,7 @@ def on_step( """ with OptimizerWrapper(opt, scheduler): - result = self.on_loss(grid_transform, target_image, source_image, level) + result = self.on_loss(grid_transform, target_image, source_image, target_image_seg, source_image_seg, level) loss: Tensor = result["loss"] loss.backward() with torch.no_grad(): @@ -471,6 +575,8 @@ def _run_level( grid_transform: Union[SequentialTransform, SpatialTransform, CompositeTransform], target_image: Image, source_image: Image, + target_image_seg: Image | None, + source_image_seg: Image | None, level, sampling: Union[Sampling, str] = Sampling.LINEAR, ): @@ -485,17 +591,37 @@ def _run_level( self.optimizer = opt if isinstance(self.max_steps, int): max_steps = self.max_steps - elif len(self.max_steps) >= level: - max_steps = self.max_steps[-1] - elif len(self.max_steps) >= level: + elif len(self.max_steps) > level: max_steps = self.max_steps[level] - - return self.run_level(grid_transform, target_image, source_image, opt, lr_sq, level, max_steps, sampling) + else: + max_steps = self.max_steps[-1] + return self.run_level( + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + level, + max_steps, + sampling, + ) def on_split_losses(self): misc_excl = set() self.loss_terms = {a: l.to(self.device) for a, l in self.loss_terms.items()} + from TPTBox.registration.ridged_intensity.affine_deepali import ( # noqa: PLC0415 + PairwiseSegImageLoss, + ) + + self.loss_pairwise_image_terms2 = self._loss_terms_of_type(PairwiseSegImageLoss) + for name, module in self.loss_terms.items(): + if name in ["Dice", "DCE"] or module.__class__.__name__ in ["Dice", "DCE"]: + self.loss_pairwise_image_terms2[name] = module + misc_excl |= set(self.loss_pairwise_image_terms2.keys()) self.loss_pairwise_image_terms = self._loss_terms_of_type(PairwiseImageLoss) + misc_excl |= set(self.loss_pairwise_image_terms.keys()) dist_terms = self._loss_terms_of_type(PointSetDistance) misc_excl |= set(dist_terms.keys()) @@ -510,11 +636,41 @@ def on_split_losses(self): misc_excl |= set(self.loss_params_terms.keys()) self.loss_misc_terms = {k: v for k, v in self.loss_terms.items() if k not in misc_excl} + def on_run_start( + self, + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + num_steps, + level, + ): + pass + + def on_run_end( + self, + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + num_steps, + level, + ): + pass + def run_level( self, grid_transform: SequentialTransform, target_image: Image, source_image: Image, + target_image_seg: Image | None, + source_image_seg: Image | None, opt: Optimizer, lr_sq: LRScheduler | None, level, @@ -527,7 +683,11 @@ def run_level( source_grid = source_image.grid() self.transform = grid_transform 
self._sample_image = SampleImage( - target=target_grid, source=source_grid, sampling=sampling, padding=PaddingMode.ZEROS, align_centers=False + target=target_grid, + source=source_grid, + sampling=sampling, + padding=PaddingMode.ZEROS, + align_centers=False, ).to(self.device) grad_sigma = self.smooth_grad self.loss_values = [] @@ -543,9 +703,29 @@ def run_level( self.register_eval_hook(print_eval_loss_hook_tqdm(level, max_steps)) elif self.verbose > 1: self.register_step_hook(print_step_loss_hook_tqdm(level, max_steps)) - - while num_steps < max_steps and not self.on_converged(): - value = self.on_step(grid_transform, target_image, source_image, opt, lr_sq, num_steps, level) + self.on_run_start( + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + num_steps, + level, + ) + while num_steps < max_steps and not self.on_converged(level): + value = self.on_step( + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + num_steps, + level, + ) num_steps += 1 with torch.no_grad(): if math.isnan(value): @@ -558,6 +738,17 @@ def run_level( self._eval_hooks = _eval_hooks self._step_hooks = _step_hooks + self.on_run_end( + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + opt, + lr_sq, + num_steps, + level, + ) def _load_initial_transform(self, transform: SpatialTransform): if self.model_init: @@ -569,7 +760,10 @@ def _load_initial_transform(self, transform: SpatialTransform): del disp_field return transform - def register_eval_hook(self, hook: Callable[[DeepaliPairwiseImageTrainer, int, int, RegistrationResult], None]) -> RemovableHandle: + def register_eval_hook( + self, + hook: Callable[[DeepaliPairwiseImageTrainer, int, int, RegistrationResult], None], + ) -> RemovableHandle: r"""Registers a evaluation hook. 
The hook will be called every time after the registration loss has been evaluated @@ -616,7 +810,14 @@ def run(self): transform = transform.to(device=self.device) grid_transform = SequentialTransform(transform) self.on_transform_update(grid_transform) - grid_transform = self._run_level(grid_transform, target_image, source_image, 0) + grid_transform = self._run_level( + grid_transform, + target_image, + source_image, + self.target_seg, + self.source_seg, + 0, + ) else: with torch.no_grad(): ## loop pyramid @@ -635,7 +836,12 @@ def run(self): transform_downsample = self.model_args.pop("downsample", 0) transform_grid = coarsest_grid.downsample(transform_downsample) - transform = self.on_make_transform(self.transform_name, grid=transform_grid, groups=1, **self.model_args) + transform = self.on_make_transform( + self.transform_name, + grid=transform_grid, + groups=1, + **self.model_args, + ) transform = self._load_initial_transform(transform) grid_transform = SequentialTransform(transform, post_transform) grid_transform = grid_transform.to(device=self.device) @@ -647,6 +853,8 @@ def run(self): with torch.no_grad(): target_image = target_pyramid[level] source_image = source_pyramid[level] + target_image_seg = self.target_pyramid_seg[level] if self.target_pyramid_seg is not None else None + source_image_seg = self.source_pyramid_seg[level] if self.source_pyramid_seg is not None else None if self.target_mask is not None: self.target_mask = torch.ceil(target_mask_pyramid[level]).to(dtype=torch.int8) ## Initialize transformation @@ -658,7 +866,14 @@ def run(self): if self.verbose > 3: print(f"Subdivided control point grid in {timer() - start:.3f}s") grid_transform.grid_(target_image.grid()) - self._run_level(grid_transform, target_image, source_image, level) + self._run_level( + grid_transform, + target_image, + source_image, + target_image_seg, + source_image_seg, + level, + ) if self.verbose > 3: print(f"Registered images in {timer() - start_reg:.3f}s") if self.verbose > 0: diff --git a/TPTBox/registration/deepali/spine_rigid_elements_reg.py b/TPTBox/registration/deepali/spine_rigid_elements_reg.py index 3650153..337d199 100644 --- a/TPTBox/registration/deepali/spine_rigid_elements_reg.py +++ b/TPTBox/registration/deepali/spine_rigid_elements_reg.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os + # pip install hf-deepali import pickle from collections.abc import Sequence @@ -22,6 +24,17 @@ from TPTBox.registration.ridged_points import Point_Registration +def not_exist_or_is_younger_than(path: Path | str, other_path: Path | str | None): + path = Path(path) + if not path.exists(): + return True + if other_path is None: + return True + fileCreation = os.path.getmtime(path) + fileCreation_ref = os.path.getmtime(other_path) + return fileCreation < fileCreation_ref + + def _load_poi(fixed_poi_file, vert: NII, subreg: NII, save_pois): buffer_file = None if fixed_poi_file is not None: @@ -73,10 +86,10 @@ def __init__( gpu=0, ddevice: DEVICES = "cuda", # Pyramid - pyramid_levels: Optional[int] = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -90,6 
+103,7 @@ def __init__( patience=100, patience_delta=0.00001, my=1.5, + orientation=None, ) -> None: self.my = my if weights is None: @@ -97,6 +111,7 @@ def __init__( if device is None: device = get_device(ddevice, gpu) self.device: torch.device = device # type: ignore + self.orientation = orientation self._resample_and_preReg( fixed_image, fixed_vert, @@ -109,6 +124,7 @@ def __init__( reference_image, crop_to_FOV, save_pois, + orientation, ) self._run_rigid_reg( @@ -144,6 +160,7 @@ def _resample_and_preReg( reference_image: Image_Reference | None, crop_to_FOV: bool, save_pois: bool, + orientation, ): # Load # fixed_image = to_nii(fixed_image_) @@ -158,7 +175,28 @@ def _resample_and_preReg( # Resample ref = fixed_vert if reference_image is None else reference_image if not isinstance(ref, Has_Grid): - ref = to_nii(reference_image) + ref = to_nii(ref) + if orientation is not None: + ref.reorient_(orientation) # type: ignore + fixed_vert.reorient_(orientation) + fixed_subreg.reorient_(orientation) + moving_image.reorient_(orientation) + moving_vert.reorient_(orientation) + moving_subreg.reorient_(orientation) + fixed_poi.reorient_(orientation) + moving_poi.reorient_(orientation) + + if not fixed_vert.assert_affine(fixed_subreg, raise_error=False): + fixed_subreg.resample_from_to_(fixed_vert) + if not fixed_vert.assert_affine(fixed_poi, raise_error=False): + fixed_poi = fixed_poi.resample_from_to(fixed_vert) + + if not moving_image.assert_affine(moving_vert, raise_error=False): + moving_vert.resample_from_to_(moving_image) + if not moving_image.assert_affine(moving_subreg, raise_error=False): + moving_subreg.resample_from_to_(moving_image) + if not moving_image.assert_affine(moving_poi, raise_error=False): + moving_poi = moving_poi.resample_from_to(moving_image) # fixed_image.resample_from_to_(ref) fixed_vert.resample_from_to_(ref) @@ -198,10 +236,10 @@ def _resample_and_preReg( def _run_rigid_reg( self, - pyramid_levels: Optional[int] = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -226,7 +264,13 @@ def _run_rigid_reg( for idx in ids_fixed: if idx not in ids_moving: continue + if idx >= 100: + continue # subreg_ct[subreg_ct == 20] = 49 + try: + name = Vertebra_Instance(idx).name + except Exception: + name = str(idx) reg = Rigid_Registration_with_Tether( self.fixed_vert.extract_label(idx) * self.fixed_subreg, self.moving_vert.extract_label(idx) * self.moving_subreg, @@ -246,8 +290,9 @@ def _run_rigid_reg( weights=weights, patience=patience, patience_delta=patience_delta, - desc=Vertebra_Instance(idx).name, + desc=name, ) + self._rigid_registrations.append(reg) self._ids.append(idx) @@ -329,7 +374,8 @@ def transform_nii( self, img: NII, gpu: int | None = None, ddevice: DEVICES | None = None, align_corners=True, padding=PaddingMode.ZEROS ) -> NII: device = get_device(ddevice, 0 if gpu is None else gpu) if ddevice is not None else self.device - + if self.orientation: + img = img.reorient(self.orientation) img_point_reg = self.point_reg.transform_nii(img) if self.crop is not None: img_point_reg.apply_crop_(self.crop) diff --git 
a/TPTBox/registration/deformable/_deepali/deform_reg_pair.py b/TPTBox/registration/deformable/_deepali/deform_reg_pair.py index 2d0083f..398131e 100644 --- a/TPTBox/registration/deformable/_deepali/deform_reg_pair.py +++ b/TPTBox/registration/deformable/_deepali/deform_reg_pair.py @@ -116,7 +116,7 @@ def append_mask(image: Image, mask_nii: NII | None, channels: dict[str, tuple[in mask = deepali_functional.threshold(data[slice(*channels["img"])], lower_threshold, upper_threshold) else: - mask = torch.ones((1,) + data.shape[1:], dtype=data.dtype, device=data.device) + mask = torch.ones((1, *data.shape[1:]), dtype=data.dtype, device=data.device) else: # torch.nn.functional.one_hot diff --git a/TPTBox/registration/deformable/deformable_reg.py b/TPTBox/registration/deformable/deformable_reg.py index 9c6f0a7..0a9c73f 100644 --- a/TPTBox/registration/deformable/deformable_reg.py +++ b/TPTBox/registration/deformable/deformable_reg.py @@ -32,6 +32,8 @@ def __init__( self, fixed_image: Image_Reference, moving_image: Image_Reference, + fixed_seg: Image_Reference | None = None, + moving_seg: Image_Reference | None = None, reference_image: Image_Reference | None = None, source_pset=None, target_pset=None, @@ -46,14 +48,13 @@ def __init__( fixed_mask: Image_Reference | None = None, moving_mask: Image_Reference | None = None, # normalize - normalize_strategy: Optional[ - Literal["auto", "CT", "MRI"] - ] = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do noting + normalize_strategy: Literal["auto", "CT", "MRI"] + | None = "auto", # Override on_normalize for finer normalization schema or normalize before and set to None. auto: [min,max] -> [0,1]; None: Do nothing # Pyramid - pyramid_levels: Optional[int] = 3, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = 3, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -68,24 +69,26 @@ def __init__( max_steps: int | Sequence[int] = 1000, # Early stopping. override on_converged finer controle max_history: int | None = None, min_value=0.0, # Early stopping. override on_converged finer controle - min_delta=-0.0001, # Early stopping. override on_converged finer controle + min_delta: float | Sequence[float] = -0.0001, # Early stopping.
override on_converged for finer control loss_terms: list[LOSS | str] | dict[str, LOSS] | dict[str, str] | dict[str, tuple[str, dict]] | None = None, - weights: list[float] | dict[str, float] | None = None, + weights: list[float] | dict[str, float | list[float]] | None = None, auto_run=True, + stride=8, ): if transform_args is None: - transform_args = {"stride": [8, 8, 16], "transpose": False} + transform_args = {"stride": [stride, stride, stride], "transpose": False} if loss_terms is None: loss_terms = { "be": BSplineBending(stride=1), "lncc": LNCC(), } if weights is None: - weights = {"be": 0.001, "seg": 1} - + weights = {"be": 0.001, "lncc": 1} super().__init__( fixed_image=fixed_image, moving_image=moving_image, + fixed_seg=fixed_seg, + moving_seg=moving_seg, reference_image=reference_image, source_pset=source_pset, target_pset=target_pset, diff --git a/TPTBox/registration/deformable/multilabel_segmentation.py b/TPTBox/registration/deformable/multilabel_segmentation.py new file mode 100644 index 0000000..d2b9a4b --- /dev/null +++ b/TPTBox/registration/deformable/multilabel_segmentation.py @@ -0,0 +1,291 @@ +import pickle +from pathlib import Path + +from TPTBox import NII, POI +from TPTBox.core.internal.deep_learning_utils import DEVICES +from TPTBox.core.poi import calc_centroids +from TPTBox.core.poi_fun.poi_global import POI_Global +from TPTBox.registration.deformable.deformable_reg import Deformable_Registration +from TPTBox.registration.ridged_intensity.affine_deepali import Tether_Seg +from TPTBox.registration.ridged_points.point_registration import Point_Registration + + +class Register_Multi_Seg: + """ + Class to perform multi-stage registration between two multi-label segmentations, including optional + landmark (point-of-interest, POI) alignment and deformable registration. If POIs are not provided, they will be computed on the fly. + + This is especially useful for aligning anatomical segmentations from MRI or CT between a target and an atlas, + optionally considering left/right flipping if segmentations are from different body sides. + + Attributes: + same_side (bool): Whether the target and atlas represent the same anatomical side (e.g., both right sides). + reg_point (Point_Registration): The rigid point-based registration component. + reg_deform (Deformable_Registration): The deformable registration component. + crop (tuple): The crop applied to both target and atlas after registration. + target_grid_org (NII): Original spatial grid of the target. + atlas_org (NII): Original spatial grid of the atlas. + target_grid (NII): Cropped spatial grid used for deformable registration. + """ + + def __init__( + self, + target: NII, + atlas: NII, + poi_cms: POI | None, + same_side: bool, + verbose=99, + gpu=0, + ddevice: DEVICES = "cuda", + loss_terms=None, # type: ignore + weights=None, + lr=0.01, + max_steps=1500, + min_delta=1e-06, + pyramid_levels=4, + coarsest_level=3, + finest_level=0, + cms_ids: list | None = None, + poi_target_cms: POI | None = None, + **args, + ): + """ + Initialize a multi-stage registration pipeline from an atlas to a target image. + + Args: + target (NII): Target image with segmentation (e.g., from a subject). + atlas (NII): Atlas image with segmentation (e.g., a reference or template). + poi_cms (POI | None): POI centroids of the atlas, used for initial point registration. + same_side (bool): Whether atlas and target represent the same body side. + verbose (int): Verbosity level for logging. + gpu (int): GPU device ID (only relevant if using GPU).
+ ddevice (DEVICES): Device type ('cuda' or 'cpu'). + loss_terms (dict): Dictionary of loss terms for deformable registration. + weights (dict): Weights for the loss terms. + lr (float): Learning rate for deformable registration optimizer. + max_steps (int): Maximum optimization steps. + min_delta (float): Minimum delta for convergence. + pyramid_levels (int): Number of resolution levels in multi-scale deformable registration. + coarsest_level (int): Coarsest level index. + finest_level (int): Finest level index. + cms_ids (list | None): List of segmentation labels used to extract POI centroids. + poi_target_cms (POI | None): Optional precomputed centroids for the target image. + **args: Additional keyword arguments passed to Deformable_Registration. + Raises: + ValueError: If an invalid axis is detected during flipping. + """ + if weights is None: + weights = {"be": 0.0001, "seg": 1, "Dice": 0.01, "Tether": 0.001} + if loss_terms is None: + loss_terms = { + "be": ("BSplineBending", {"stride": 1}), + "seg": "MSE", + "Dice": "Dice", + "Tether": Tether_Seg(delta=5), + } + + assert target.seg, target.seg + assert atlas.seg + target = target.copy() + atlas = atlas.copy() + self.same_side = same_side + self.target_grid_org = target.to_gird() + self.atlas_org = atlas.to_gird() + if not same_side: + axis = target.get_axis("R") + if axis == 0: + target = target.set_array(target.get_array()[::-1]).copy() + elif axis == 1: + target = target.set_array(target.get_array()[:, ::-1]).copy() + elif axis == 2: + target = target.set_array(target.get_array()[:, :, ::-1]).copy() + else: + raise ValueError(axis) + if poi_target_cms is not None: + axis = poi_target_cms.get_axis("R") + for k1, k2, (x, y, z) in poi_target_cms.copy().items(): + if axis == 0: + poi_target_cms[k1, k2] = (poi_target_cms.shape[0] - 1 - x, y, z) + elif axis == 1: + poi_target_cms[k1, k2] = (x, poi_target_cms.shape[1] - 1 - y, z) + elif axis == 2: + poi_target_cms[k1, k2] = (x, y, poi_target_cms.shape[2] - 1 - z) + print("crop") + crop = 50 + t_crop = (target).compute_crop(0, crop) + target = target.apply_crop(t_crop) + if atlas.is_segmentation_in_border(): + atlas = atlas.apply_pad(((1, 1), (1, 1), (1, 1))) + for i in range(10): # 1000, + if i != 0: + target = target.apply_pad(((25, 25), (25, 25), (25, 25))) + crop += 50 + t_crop = (target).compute_crop(0, crop) # if the angle is too different we need a larger crop...
+ target_ = target.apply_crop(t_crop) + # Point Registration + print("calc_centroids") + if poi_target_cms is None: + x = target_.extract_label(cms_ids, keep_label=True) if cms_ids else target_ + poi_target = calc_centroids(x, second_stage=40, bar=True) # TODO REMOVE + else: + poi_target = poi_target_cms.resample_from_to(target_) + if poi_cms is None: + x = atlas.extract_label(cms_ids, keep_label=True) if cms_ids else atlas + poi_cms = calc_centroids(x, second_stage=40, bar=True) # This will be needlessly computed all the time + if not poi_cms.assert_affine(atlas, raise_error=False): + poi_cms = poi_cms.resample_from_to(atlas) + self.reg_point = Point_Registration(poi_target, poi_cms) + atlas_reg = self.reg_point.transform_nii(atlas) + if atlas_reg.is_segmentation_in_border(): + print("atlas_reg does touch the border") + else: + target = target_ + break + target = target_ + self.crop = (target + atlas_reg).compute_crop(0, 5) + target = target.apply_crop(self.crop) + atlas_reg = atlas_reg.apply_crop(self.crop) + self.target_grid = target.to_gird() + self.reg_deform = Deformable_Registration( + target, + atlas_reg, + target.copy(), + atlas_reg.copy(), + loss_terms=loss_terms, + weights=weights, + lr=lr, + max_steps=max_steps, + min_delta=min_delta, + pyramid_levels=pyramid_levels, + coarsest_level=coarsest_level, + finest_level=finest_level, + verbose=verbose, + gpu=gpu, + ddevice=ddevice, + **args, + ) + + def get_dump(self): + """ + Collect serializable state of the registration object. + + Returns: + tuple: Serialized components including version, rigid registration state, deformable state, + and spatial metadata. + """ + return ( + 1, # version + (self.reg_point.get_dump()), + (self.reg_deform.get_dump()), + (self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop), + ) + + def save(self, path: str | Path): + """ + Save the registration state to a file. + + Args: + path (str | Path): Path to save the pickle file. + """ + with open(path, "wb") as w: + pickle.dump(self.get_dump(), w) + + @classmethod + def load(cls, path): + """ + Load a previously saved registration state from a file. + + Args: + path (str | Path): Path to the pickle file. + + Returns: + Register_Multi_Seg: Reconstructed instance of the class. + """ + with open(path, "rb") as w: + return cls.load_(pickle.load(w)) + + @classmethod + def load_(cls, w): + """ + Load a registration object from a deserialized state (as returned by `get_dump()`). + + Args: + w (tuple): Serialized state. + + Returns: + Register_Multi_Seg: Reconstructed instance of the class. + """ + (version, t0, t1, x) = w + assert version == 1, f"Version mismatch {version=}" + self = cls.__new__(cls) + self.reg_point = Point_Registration.load_(t0) + self.reg_deform = Deformable_Registration.load_(t1) + self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop = x + return self + + def forward_nii(self, nii_atlas: NII): + """ + Apply both rigid and deformable registration to a new NII object. + + Args: + nii_atlas (NII): New atlas image to be transformed. + + Returns: + NII: Transformed image aligned with the original target image. 
+ """ + nii_atlas = self.reg_point.transform_nii(nii_atlas) + nii_atlas = nii_atlas.apply_crop(self.crop) + nii_reg = self.reg_deform.transform_nii(nii_atlas) + if nii_reg.seg: + nii_reg.set_dtype_("smallest_uint") + out = nii_reg.resample_from_to(self.target_grid_org) + if self.same_side: + return out + axis = out.get_axis("R") + if axis == 0: + target = out.set_array(out.get_array()[::-1]).copy() + elif axis == 1: + target = out.set_array(out.get_array()[:, ::-1]).copy() + elif axis == 2: + target = out.set_array(out.get_array()[:, :, ::-1]).copy() + else: + raise ValueError(axis) + + return target + + def forward_poi(self, poi_atlas: POI_Global | POI): + """ + Apply both rigid and deformable registration to a POI (landmark) object. + + Args: + poi_atlas (POI_Global | POI): Atlas landmarks to be transformed. + + Returns: + POI: Transformed POIs aligned to the target space. + """ + poi_atlas = poi_atlas.resample_from_to(self.atlas_org) + + # Point Reg + poi_atlas = self.reg_point.transform_poi(poi_atlas) + # Deformable + poi_atlas = poi_atlas.apply_crop(self.crop) + + poi_reg = self.reg_deform.transform_poi(poi_atlas) + poi_reg = poi_reg.resample_from_to(self.target_grid_org) + if self.same_side: + return poi_reg + for k1, k2, v in poi_reg.copy().items(): + k = k1 # % 100 + poi_reg[k, k2] = v + poi_reg_flip = poi_reg.make_empty_POI() + for k1, k2, (x, y, z) in poi_reg.copy().items(): + axis = poi_reg.get_axis("R") + if axis == 0: + poi_reg_flip[k1, k2] = (poi_reg.shape[0] - 1 - x, y, z) + elif axis == 1: + poi_reg_flip[k1, k2] = (x, poi_reg.shape[1] - 1 - y, z) + elif axis == 2: + poi_reg_flip[k1, k2] = (x, y, poi_reg.shape[2] - 1 - z) + else: + raise ValueError(axis) + return poi_reg_flip diff --git a/TPTBox/registration/ridged_intensity/affine_deepali.py b/TPTBox/registration/ridged_intensity/affine_deepali.py index 0a9e58d..1dad5bb 100644 --- a/TPTBox/registration/ridged_intensity/affine_deepali.py +++ b/TPTBox/registration/ridged_intensity/affine_deepali.py @@ -1,59 +1,303 @@ from __future__ import annotations +from abc import ABCMeta, abstractmethod + # pip install hf-deepali -import json -import pickle -import time from collections.abc import Sequence -from pathlib import Path -from typing import Literal, Optional, Union +from copy import deepcopy +from typing import Literal, Union import torch import torch.optim -import yaml from deepali import spatial -from deepali.core import Axes, PathStr, Sampling -from deepali.core import Grid as Deepali_Grid +from deepali.core import PathStr, Sampling from deepali.data import Image as deepaliImage from deepali.losses import ( - BSplineLoss, - DisplacementLoss, - LandmarkPointDistance, PairwiseImageLoss, - ParamsLoss, - PointSetDistance, ) -from deepali.losses.functional import mse_loss, ncc_loss -from deepali.modules import TransformImage -from deepali.spatial import SpatialTransform +from deepali.losses.functional import ncc_loss +from torch import Tensor +from torch.nn import Module from tqdm import tqdm -from TPTBox import NII, POI, Image_Reference, to_nii -from TPTBox.core.compat import zip_strict +from TPTBox import Image_Reference from TPTBox.core.internal.deep_learning_utils import DEVICES, get_device -from TPTBox.core.nii_poi_abstract import Grid as TPTBox_Grid -from TPTBox.core.nii_poi_abstract import Has_Grid from TPTBox.registration.deepali.deepali_model import General_Registration from TPTBox.registration.deepali.deepali_trainer import PairwiseImageLoss -from TPTBox.registration.ridged_points import Point_Registration + + +class 
PairwiseSegImageLoss(Module, metaclass=ABCMeta): + r"""Base class of pairwise image dissimilarity criteria.""" + + @abstractmethod + def forward(self, source: Tensor, target: Tensor, mask: Tensor | None = None) -> Tensor: + r"""Evaluate image dissimilarity loss.""" + raise NotImplementedError(f"{type(self).__name__}.forward()") def center_of_mass(tensor): grid = torch.meshgrid([torch.arange(s, device=tensor.device) for s in tensor.shape], indexing="ij") - com = torch.stack([(tensor * g).sum() / tensor.sum() for g in grid]) + t = tensor / tensor.sum() + com = torch.stack([(t * g).sum() for g in grid]) return com -class Tether(PairwiseImageLoss): - def forward(self, source: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: # noqa: ARG002 +class Tether_single(PairwiseImageLoss): + def forward( + self, + source: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: # noqa: ARG002 com_fixed = center_of_mass(target) com_warped = center_of_mass(source) l_com = torch.norm(com_fixed - com_warped) if l_com < 10: l_com = source.sum() * 0 + l_com = torch.nan_to_num(l_com, nan=0) return l_com # type: ignore +def center_of_mass_cc(tensor: torch.Tensor) -> torch.Tensor: + """ + Computes the center of mass for each channel in a (B, C, X, Y, Z) tensor. + Returns a tensor of shape (B, C, 3) containing the (x, y, z) coordinates per channel. + """ + dtype = tensor.dtype + B, C, *spatial_shape = tensor.shape + tensor = tensor.float() + grid = torch.meshgrid([torch.arange(s, device=tensor.device) for s in spatial_shape], indexing="ij") # each g is (X, Y, Z) + grid = torch.stack(grid, dim=0) # (3, X, Y, Z) + + # Flatten spatial dims + tensor_flat = tensor.view(B, C, -1) # (B, C, X*Y*Z) + grid_flat = grid.view(3, -1) # (3, X*Y*Z) + + # Normalize tensor + norm = tensor_flat.sum(dim=-1, keepdim=True) # (B, C, 1) + norm[norm == 0] = 1 # avoid division by zero + + com = torch.einsum("bcn,nm->bcm", tensor_flat, grid_flat.T.to(tensor_flat.dtype)) / norm # (B, C, 3) + return com.to(dtype) + + +class Tether_Seg(PairwiseSegImageLoss): + def __init__(self, delta=1, *args, **kwargs): + self.delta = delta + super().__init__(*args, **kwargs) + + def forward( + self, + source: torch.Tensor, # shape: (B, C, X, Y, Z) + target: torch.Tensor, # shape: (B, C, X, Y, Z) + mask: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: + w = max(target.shape[2:]) + com_fixed = center_of_mass_cc(target) # (B, C, 3) + com_warped = center_of_mass_cc(source) # (B, C, 3) + + l_com = torch.norm(com_fixed - com_warped, dim=-1) / w # (B, C) + + # Zero out channels with small displacement (< self.delta) or NaNs + l_com = torch.where(l_com < self.delta, torch.zeros_like(l_com), l_com) + l_com = torch.nan_to_num(l_com, nan=0.0) + + return l_com.mean() # type: ignore + + +class Tether(PairwiseImageLoss): + def __init__( + self, + delta=10, + uniq=False, + remember=False, + remember_c=10, + max_v=1, + *args, + **kwargs, + ) -> None: + self.delta = delta + self.uniq = uniq + self.remember = remember + self.remember_c = remember_c + self.count = 0 + self.max_v = max_v + super().__init__(*args, **kwargs) + + def forward( + self, + source: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: # noqa: ARG002 + if self.count != 0: + self.count -= 1 + return torch.zeros(1, device=source.device) + if self.uniq: + loss = torch.zeros(1, device=source.device) + k = 0 + target = (target * 
self.max_v).round(decimals=0) + source = (source * self.max_v).round(decimals=0) + u = torch.unique(target) + for i in u: + if i == 0: + continue + com_fixed = center_of_mass(target == i) + com_warped = center_of_mass(source == i) + l_com = torch.norm(com_fixed - com_warped) + l_com = torch.nan_to_num(l_com, nan=0) + # print(l_com) + if l_com > self.delta: + loss += l_com + k += 1 + # print(loss / k, k, len(u)) + if k == 0: + if self.remember: + self.count = self.remember_c + return loss + return loss / k + else: + com_fixed = center_of_mass(target != 0) + com_warped = center_of_mass(source != 0) + l_com = torch.norm(com_fixed - com_warped) + if l_com < self.delta: + l_com = torch.zeros(1, device=source.device) + if self.remember: + self.count = self.remember_c + l_com = torch.nan_to_num(l_com, nan=0) + return l_com # type: ignore + + +def subsample_coords(coords: torch.Tensor, k: int) -> torch.Tensor: + """ + If `coords` has more than `k` rows, return a random subset of size `k`; + otherwise return `coords` unchanged. + + Uses sampling *without* replacement (`torch.randperm`), so every + coordinate appears at most once. Works entirely on-device. + """ + n = coords.size(0) + if n <= k: + return coords + idx = torch.randperm(n, device=coords.device)[:k] + return coords[idx] + + +class DISTANCE_to_TARGET(PairwiseImageLoss): + def __init__( + self, + max_v=1, + res_gt=4, + *args, + **kwargs, + ) -> None: + self.max_v = max_v + self.res_gt = res_gt + super().__init__(*args, **kwargs) + + def forward( + self, + source: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: + """ + Chamfer-style distance loss for mis-labelled voxels. + + Parameters + ---------- + source : (D, H, W)[, …] torch.Tensor + Model prediction in label form (one channel per voxel). + target : (D, H, W)[, …] torch.Tensor + Ground-truth labels. + max_v : float, default 1 + Same scale factor you use elsewhere to map the continuous range [0, 1] + back to integer labels. Set to 1 if `source` and `target` are already + integer encoded. + + Returns + ------- + torch.Tensor scalar + The mean distance (in voxel units) from every wrongly predicted voxel + to the nearest correct voxel of the same class in the target. + """ + max_v = self.max_v + device = source.device + # Discretise + src = (source * max_v).round().short() # .long() + tgt = (target * max_v).round().short() # .long() + + classes = torch.unique(tgt) + classes = classes[classes != 0] # skip background label 0 + + if classes.numel() == 0: + return torch.zeros(1, device=device) + + per_class_losses = [] + + for c in classes: + wrong_mask = (src == c) & (tgt != c) # voxels we predicted as c but shouldn't + if not wrong_mask.any(): + continue # no penalty if we never made that error + res_gt = self.res_gt + gt_mask = tgt[..., ::res_gt, ::res_gt, ::res_gt] == c + if not gt_mask.any(): + # Optional: if the class is missing in GT you could add + # a constant penalty or skip it. Here we skip.
+ continue + + # Coordinates of voxels + wrong_coords = torch.nonzero(wrong_mask, as_tuple=False).float() + # print(gt_mask.shape) + # + gt_coords = torch.nonzero(gt_mask, as_tuple=False).float() + gt_coords[:, -3:] *= res_gt # map indices of the subsampled GT grid back to full-resolution voxel space + + # Pairwise distances (N_wrong, N_gt); differentiable + d = torch.cdist(subsample_coords(wrong_coords, 5000), gt_coords) + min_dists = d.min(dim=1).values # (N_wrong,) + + per_class_losses.append(min_dists.mean()) + + if not per_class_losses: + # Nothing to penalise - perfect overlap + return torch.zeros(1, device=device) + + # Average over foreground classes + return torch.stack(per_class_losses).mean() + + +class LABEL_LOSS(PairwiseImageLoss): + def __init__( + self, + max_v=1, + *args, + **kwargs, + ) -> None: + self.max_v = max_v + super().__init__(*args, **kwargs) + + def forward( + self, + source: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: # noqa: ARG002 + loss = torch.zeros(1, device=source.device) + target = (target * self.max_v).round(decimals=0) + source = (source * self.max_v).round(decimals=0) + u = torch.unique(target) + for i in u: + if i == 0: + continue + com_fixed = (target == i) ^ (source == i) + l_com = com_fixed.sum() / ((source == i).sum() + (target == i).sum()) + l_com = torch.nan_to_num(l_com, nan=0) + print(i, l_com) + loss += l_com + + return loss / len(u) + + class Rigid_Registration_with_Tether(General_Registration): def __init__( self, @@ -67,12 +311,12 @@ def __init__( fixed_mask=None, moving_mask=None, # normalize - normalize_strategy: Optional[Literal["auto", "CT", "MRI"]] = None, + normalize_strategy: Literal["auto", "CT", "MRI"] | None = None, # Pyramid - pyramid_levels: Optional[int] = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) + pyramid_levels: int | None = None, # 1/None = no pyramid; int: number of stacks, tuple from to (0 is finest) finest_level: int = 0, - coarsest_level: Optional[int] = None, - pyramid_finest_spacing: Optional[Sequence[int] | torch.Tensor] = None, + coarsest_level: int | None = None, + pyramid_finest_spacing: Sequence[int] | torch.Tensor | None = None, pyramid_min_size=16, dims=("x", "y", "z"), align=False, @@ -186,8 +430,6 @@ def on_converged(self) -> bool: self.early_stopping += 1 if value <= self.best2: - from copy import deepcopy - self.best_transform = deepcopy(self.transform) self.best2 = value diff --git a/TPTBox/registration/ridged_points/point_registration.py b/TPTBox/registration/ridged_points/point_registration.py index b812ed7..b53bf97 100755 --- a/TPTBox/registration/ridged_points/point_registration.py +++ b/TPTBox/registration/ridged_points/point_registration.py @@ -3,7 +3,6 @@ import math import pickle import warnings -from dataclasses import dataclass from pathlib import Path from typing import TypeVar @@ -11,7 +10,6 @@ import SimpleITK as sitk # noqa: N813 from TPTBox import ( - AX_CODES, NII, POI, Has_Grid, @@ -79,7 +77,7 @@ def __init__( log.print(f_keys, verbose=verbose) log.print(poi_fixed.orientation, verbose=verbose) - if len(inter) <= 2: + if len(inter) < 2: log.print("[!] To few points, skip registration", Log_Type.FAIL) raise ValueError("[!] 
To few points, skip registration", inter) img_movig = poi_moving.make_empty_nii() @@ -88,7 +86,13 @@ def __init__( if leave_worst_percent_out != 0.0: poi_fixed = poi_fixed.intersect(poi_moving) init_transform, error_reg, error_natural, delta_after = _compute_versor( - inter, poi_fixed, representative_f_sitk, poi_moving, representative_m_sitk, verbose=False, log=log + inter, + poi_fixed, + representative_f_sitk, + poi_moving, + representative_m_sitk, + verbose=False, + log=log, ) delta_after = sorted(delta_after.items(), key=lambda x: -x[1]) out_str = f"Did not use the following keys for registaiton (worst {leave_worst_percent_out * 100} %) " @@ -104,7 +108,13 @@ def __init__( # limit to only shared labels inter = [x for x in f_keys if x in m_keys] init_transform, error_reg, error_natural, _ = _compute_versor( - inter, poi_fixed, representative_f_sitk, poi_moving, representative_m_sitk, verbose=verbose, log=log + inter, + poi_fixed, + representative_f_sitk, + poi_moving, + representative_m_sitk, + verbose=verbose, + log=log, ) self._transform: sitk.VersorRigid3DTransform = init_transform @@ -163,8 +173,20 @@ def transform_cord(self, cord: tuple[float, ...], out: sitk.Image | None = None) ctr_b = out.TransformPhysicalPointToContinuousIndex(ctr_b) return np.array(ctr_b) + def transform_cord_inverse(self, cord: tuple[float, ...], out: sitk.Image | None = None): + if out is None: + out = self._img_fixed + ctr_b = out.TransformContinuousIndexToPhysicalPoint(cord) + ctr_b = self._transform.TransformPoint(ctr_b) + ctr_b = self._img_moving.TransformPhysicalPointToContinuousIndex(ctr_b) + return np.array(ctr_b) + def transform_nii( - self, moving_img_nii: NII, allow_only_same_grid_as_moving=True, output_space: NII | None = None, c_val: float | None = None + self, + moving_img_nii: NII, + allow_only_same_grid_as_moving=True, + output_space: NII | None = None, + c_val: float | None = None, ): if allow_only_same_grid_as_moving: text = "input image must be in the same space as moving. 
If you sure that this input is in same space as the moving image you can turn of 'only_allow_grid_as_moving'" @@ -297,7 +319,11 @@ def ridged_points_from_subreg_vert( semantic_nii = to_nii(subreg, True).copy() target_poi = ( calc_poi_from_subreg_vert( - instance_nii, semantic_nii, subreg_id=subreg_id, buffer_file=poi_target_buffer, save_buffer_file=save_buffer_file + instance_nii, + semantic_nii, + subreg_id=subreg_id, + buffer_file=poi_target_buffer, + save_buffer_file=save_buffer_file, ) .copy() .extract_subregion_(*subreg_id) @@ -319,7 +345,7 @@ def _compute_versor( verbose=False, log: Logger_Interface = No_Logger(), # noqa: B008 ): - assert len(inter) > 2, f"To few points: {inter}" + assert len(inter) >= 2, f"Too few points: {inter}" # find shared points move_l = [] fix_l = [] @@ -336,7 +362,11 @@ def _compute_versor( moving_image_points_flat = [c for p in move_l for c in p if not math.isnan(c)] fixed_image_points_flat = [c for p in fix_l for c in p if not math.isnan(c)] init_transform = sitk.VersorRigid3DTransform( - sitk.LandmarkBasedTransformInitializer(sitk.VersorRigid3DTransform(), fixed_image_points_flat, moving_image_points_flat) + sitk.LandmarkBasedTransformInitializer( + sitk.VersorRigid3DTransform(), + fixed_image_points_flat, + moving_image_points_flat, + ) ) x_old = fix_l[0] @@ -371,14 +401,25 @@ def _compute_versor( y_ = f"{y[0]:7.1f},{y[1]:7.1f},{y[2]:7.1f}" y2_ = f"{y2[0]:7.1f},{y2[1]:7.1f},{y2[2]:7.1f}" d_ = f"{dif[0]:7.1f},{dif[1]:7.1f},{dif[2]:7.1f}" - log.print(f"{(k1, k2)!s: <7}|{x_: <23}|{y2_: <23}|{y_: <23}|{d_: <23}|{dist!s: <5}|{dist2!s: <5}|", verbose=verbose) + log.print( + f"{(k1, k2)!s: <7}|{x_: <23}|{y2_: <23}|{y_: <23}|{d_: <23}|{dist!s: <5}|{dist2!s: <5}|", + verbose=verbose, + ) x_old = x y_old = y k_old = k1 error_reg /= max(err_count, 1) error_natural /= max(err_count_n, 1) - log.print(f"Error avg registration error-vector length: {error_reg: 7.3f}", Log_Type.STAGE, verbose=verbose) - log.print(f"Error avg point-distances: {error_natural: 7.3f}", Log_Type.STAGE, verbose=verbose) + log.print( + f"Error avg registration error-vector length: {error_reg: 7.3f}", + Log_Type.STAGE, + verbose=verbose, + ) + log.print( + f"Error avg point-distances: {error_natural: 7.3f}", + Log_Type.STAGE, + verbose=verbose, + ) return init_transform, error_reg, error_natural, delta_after diff --git a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py index 3607381..13a54e4 100644 --- a/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py @@ -58,6 +58,11 @@ def run_inference_on_file( ddevice: Literal["cpu", "cuda", "mps"] = "cuda", _model_path=None, step_size=0.5, + memory_base=5000, # Base memory in MB, default is 5GB + memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max=160000, # in MB, default is 160GB + wait_till_gpu_percent_is_free=0.1, + verbose=True, ) -> tuple[Image_Reference, np.ndarray | None]: global model_path # noqa: PLW0603 if _model_path is not None: @@ -67,24 +72,25 @@ def run_inference_on_file( if out_file is not None and Path(out_file).exists() and not override: return out_file, None - from TPTBox.segmentation.nnUnet_utils.inference_api import ( - load_inf_model, - run_inference, - ) + from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference # noqa: PLC0415 download_weights(idx, model_path) try: nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNet*ResEnc*")) 
except StopIteration: nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*")) - folds = [int(f.name.split("fold_")[-1]) for f in nnunet_path.glob("fold*")] + folds = sorted([int(f.name.split("fold_")[-1]) for f in nnunet_path.glob("fold*")]) if max_folds is not None: folds = max_folds if isinstance(max_folds, list) else folds[:max_folds] # if idx in _unets: # nnunet = _unets[idx] # else: - + print("load model", nnunet_path.name, "; folds", folds) if verbose else None + with open(Path(nnunet_path, "plans.json")) as f: + plans_info = json.load(f) + with open(Path(nnunet_path, "dataset.json")) as f: + ds_info = json.load(f) nnunet = load_inf_model( nnunet_path, allow_non_final=True, @@ -92,13 +98,13 @@ def run_inference_on_file( gpu=gpu, ddevice=ddevice, step_size=step_size, + memory_base=memory_base, + memory_factor=memory_factor, + memory_max=memory_max, + wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, ) # _unets[idx] = nnunet - with open(Path(nnunet_path, "plans.json")) as f: - plans_info = json.load(f) - with open(Path(nnunet_path, "dataset.json")) as f: - ds_info = json.load(f) if "orientation" in ds_info: orientation = ds_info["orientation"] zoom = None @@ -117,10 +123,13 @@ def run_inference_on_file( nnunet_path, ) if orientation is not None: + print("orientation", orientation) if verbose else None input_nii = [i.reorient(orientation) for i in input_nii] if zoom is not None: + print("rescale", zoom) if verbose else None input_nii = [i.rescale_(zoom, mode=mode) for i in input_nii] + print("squash to float16") if verbose else None input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii] if crop: @@ -168,6 +177,10 @@ def run_total_seg( **_kargs, ): global model_path # noqa: PLW0603 + if out_path.exists() and not override: + logger.print(out_path, "already exists. SKIP!", Log_Type.OK) + return out_path + if _model_path is not None: model_path = _model_path if dataset_id is None: @@ -187,9 +200,6 @@ def run_total_seg( return else: download_weights(dataset_id) - if out_path.exists() and not override: - logger.print(out_path, "already exists. SKIP!", Log_Type.OK) - return out_path selected_gpu = gpu if gpu is None: gpu = "auto" # type: ignore @@ -221,4 +231,5 @@ def run_total_seg( crop=crop, max_folds=max_folds, step_size=step_size, + **_kargs, )[0] diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 75c7534..0c015d9 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -32,6 +32,10 @@ def load_inf_model( use_gaussian=True, verbose: bool = False, gpu=None, + memory_base=5000, # Base memory in MB, default is 5GB + memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max=160000, # in MB, default is 160GB + wait_till_gpu_percent_is_free=0.3 ) -> nnUNetPredictor: """Loads the Nako-Segmentor Model Predictor @@ -41,6 +45,9 @@ def load_inf_model( "the prediction. Default: 0.5. Cannot be larger than 1. ddevice (str, optional): The device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID!. Defaults to "cuda". + memory_base (int, optional): Base memory in MB for the model. Default is 5000 MB (5GB). + memory_factor (int, optional): Memory factor for the model. Default is 160, which is ~30GB for a 512x512x512 image. + memory_max (int, optional): Maximum memory in MB for the model. 
Default is 160000 MB (160GB). Returns: predictor: Loaded model predictor object @@ -77,6 +84,10 @@ def load_inf_model( verbose=verbose, verbose_preprocessing=False, cuda_id=0 if gpu is None else gpu, + memory_base=memory_base, + memory_factor=memory_factor, + memory_max=memory_max, + wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free ) check_name = "checkpoint_final.pth" # if not allow_non_final else "checkpoint_best.pth" try: diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index d19bcee..eecbb1f 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -8,12 +8,11 @@ import traceback from dataclasses import dataclass, field from math import ceil, floor -from pathlib import Path import numpy as np import torch from acvl_utils.cropping_and_padding.padding import pad_nd_image -from batchgenerators.utilities.file_and_folder_operations import join, load_json, load_pickle +from batchgenerators.utilities.file_and_folder_operations import join, load_json from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels from torch._dynamo import OptimizedModule from tqdm import tqdm @@ -50,6 +49,10 @@ def __init__( verbose: bool = False, verbose_preprocessing: bool = False, allow_tqdm: bool = True, + memory_base=5000, # Base memory in MB, default is 5GB + memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max: int = 160000, # in MB, default is 160GB + wait_till_gpu_percent_is_free=0.3, ): self.verbose = verbose self.verbose_preprocessing = verbose_preprocessing @@ -76,6 +79,10 @@ def __init__( perform_everything_on_gpu = False self.device = device self.perform_everything_on_gpu = perform_everything_on_gpu + self.memory_base = memory_base + self.memory_factor = memory_factor + self.memory_max = memory_max + self.wait_till_gpu_percent_is_free = wait_till_gpu_percent_is_free def initialize_from_trained_model_folder( self, @@ -102,7 +109,14 @@ def initialize_from_trained_model_folder( trainer.output_folder_base = model_training_output_dir trainer.update_fold(0) trainer.initialize(False) - all_best_model_files = [join(model_training_output_dir, f"fold_{i}", "model_final_checkpoint.model") for i in use_folds] + all_best_model_files = [ + join( + model_training_output_dir, + f"fold_{i}", + "model_final_checkpoint.model", + ) + for i in use_folds + ] print("using the following model files: ", all_best_model_files) all_params = [torch.load(i, map_location=torch.device("cpu"))["state_dict"] for i in all_best_model_files] plans = trainer.plans @@ -303,6 +317,11 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts=10) traceback.print_exc() prediction = None self.perform_everything_on_gpu = False + empty_cache(self.device) + if attempts == 0: + raise + + return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1) # CPU version if prediction is None: @@ -328,8 +347,8 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts=10) prediction /= len(self.list_of_parameters) # type: ignore except RuntimeError: print(f"failed due to insufficient GPU memory. 
{attempts} attempts remaining.") - print("Error:") - traceback.print_exc() + # print("Error:") + # traceback.print_exc() empty_cache(self.device) if attempts == 0: raise @@ -430,7 +449,10 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False # is set. Why. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. - with torch.no_grad(), torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context(): + with ( + torch.no_grad(), + torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context(), + ): assert len(input_image.shape) == 4, "input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)" if self.verbose: print(f"Input shape: {input_image.shape}") @@ -451,27 +473,32 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ # print("pixel", np.prod(shape) / 1000000) # print("memory", get_gpu_memory_MB(device), device) - if get_gpu_util(device) > 0.7: - t = tqdm(range(1200)) # Wait 20 minutes + if get_gpu_util(device) > 1 - self.wait_till_gpu_percent_is_free: + t = tqdm(range(2400)) # Wait 40 minutes for i in t: util = get_gpu_util(device) - th = 0.7 + th = 1 - self.wait_till_gpu_percent_is_free if i > 60: - th = 0.8 + th = 1 - self.wait_till_gpu_percent_is_free / 4 * 3 if i > 180: - th = 0.9 - t.desc = f"not enough gpu space {util:.2f} must be under {th:.1f}" + th = 1 - self.wait_till_gpu_percent_is_free / 2 + if i > 1200: + th = 1 - self.wait_till_gpu_percent_is_free / 4 + t.desc = f"not enough gpu space in percent {util:.2f} must be under {th:.1f}" if util < th: break time.sleep(1) - def check_mem(shape, max_memory=160000, min_memory=5000, factor=160): + def check_mem(shape): memory = get_gpu_memory_MB(device) + max_memory = self.memory_max + min_memory = self.memory_base + factor = self.memory_factor # print(shape, "usage", np.prod(shape) / 1000000 * factor, max(min(memory, max_memory), min_memory)) return (np.prod(shape) / 1000000 * factor) + min_memory // 2 < max(min(memory, max_memory), min_memory) with tqdm(total=len(slicers), disable=not self.allow_tqdm) as pbar: - if not check_mem(shape): + if not check_mem(shape) or "nnUNetPlans_2d" not in self.configuration_manager.configuration.get("data_identifier", "3D"): pbar.desc = "splitting in to chunks" pbar.update(0) splits = [1 for _ in shape] @@ -486,9 +513,20 @@ def check_mem(shape, max_memory=160000, min_memory=5000, factor=160): shape_split = [ceil(s / sp) for s, sp in zip(shape, splits)] # print(shape, patch_size, splits, s, np.prod(shape) / 1000000) if check_mem(shape_split): - return self._run_prediction_splits( - data, network, global_shape=shape, splits=splits, slicers=slicers, pbar=pbar - )[(slice(None), *slicer_revert_padding[1:])] + try: + return self._run_prediction_splits( + data, + network, + global_shape=shape, + splits=splits, + slicers=slicers, + pbar=pbar, + )[(slice(None), *slicer_revert_padding[1:])] + except AttributeError as e: + print("_run_prediction_splits failed; falling back to non-split prediction") + print(e) + break + predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar) pbar.desc = "finish" pbar.update(0) @@ -498,7 +536,15 @@ def check_mem(shape, max_memory=160000, min_memory=5000, factor=160): empty_cache(self.device) return predicted_logits[(slice(None), *slicer_revert_padding[1:])] - def 
_run_prediction_splits(self, data, network, global_shape, splits: list[int], slicers: list[tuple[slice, ...]], pbar: tqdm): + def _run_prediction_splits( + self, + data, + network, + global_shape, + splits: list[int], + slicers: list[tuple[slice, ...]], + pbar: tqdm, + ): widths = [ceil(s / sp) for s, sp in zip(global_shape, splits)] inter_mediate_slice: list[intermediate_slice] = [] pbar.desc = "split in to GPU chunks" @@ -519,15 +565,16 @@ def _run_prediction_splits(self, data, network, global_shape, splits: list[int], raise ValueError(s) # print(inter_mediate_slice) predicted_logits, n_predictions, _, _ = self._allocate(data, "cpu", pbar) - for i in inter_mediate_slice: + for e, i in enumerate(inter_mediate_slice, 1): slices = i.get_intermediate() sub_data = data[slices] - logits, n_pred = self._run_sub(sub_data, network, self.device, i.get_slices(), pbar) + logits, n_pred = self._run_sub( + sub_data, network, self.device, i.get_slices(), pbar, addendum=f"chunks={e}/{len(inter_mediate_slice)}" + ) pbar.desc = "save back chunk to cpu" pbar.update(0) logits = logits.cpu() n_pred = n_pred.cpu() - empty_cache(self.device) predicted_logits[slices] += logits n_predictions[slices[1:]] += n_pred del logits @@ -557,6 +604,9 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss=True): device=results_device, ) except RuntimeError as e: + n_predictions = None + gaussian = 1 + predicted_logits = 1 print("allocate FALL BACK CPU") # raise empty_cache(self.device) print(e) @@ -575,24 +625,38 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss=True): value_scaling_factor=1000, device=results_device, ) - finally: - empty_cache(self.device) + # finally: + # empty_cache(self.device) return predicted_logits, n_predictions, gaussian, results_device - def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm): - data = data.to(self.device) # type: ignore - predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) - pbar.desc = "running prediction" - for sl in slicers: - pbar.update(1) - work_on = data[sl][None] - work_on = work_on.to(self.device, non_blocking=False) - prediction = self._internal_maybe_mirror_and_predict(work_on, network=network)[0].to(results_device) - if prediction.shape[0] != predicted_logits.shape[0]: - prediction.squeeze_(0) - predicted_logits[sl] += prediction * gaussian if self.use_gaussian else prediction - n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 - return predicted_logits, n_predictions + def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum=""): + try: + data = data.to(self.device) # type: ignore + predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) + pbar.desc = f"running prediction {addendum}" + prediction = None + work_on = None + for sl in slicers: + pbar.update(1) + work_on = data[sl][None] + work_on = work_on.to(self.device, non_blocking=False) + prediction = self._internal_maybe_mirror_and_predict(work_on, network=network)[0].to(results_device) + if prediction.shape[0] != predicted_logits.shape[0]: + prediction.squeeze_(0) + predicted_logits[sl] += prediction * gaussian if self.use_gaussian else prediction + n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 + return predicted_logits, n_predictions # noqa: TRY300 + except RuntimeError: + del predicted_logits + del n_predictions + del gaussian + del work_on + del prediction + 
empty_cache(self.device) + empty_cache(results_device) + self.memory_base += 1000 + self.memory_factor += 10 + raise @dataclass @@ -625,12 +689,18 @@ def add_slicer(self, s: tuple[slice, ...]): self.slicers.append(s) def get_intermediate(self): - return (slice(None), *tuple(slice(mi, ma) for mi, ma in zip(self.min_s, self.max_s))) # type: ignore + return ( + slice(None), + *tuple(slice(mi, ma) for mi, ma in zip(self.min_s, self.max_s)), + ) # type: ignore def get_slices(self): assert self.min_s is not None for s in self.slicers: - yield (slice(None), *tuple(slice(a.start - o, a.stop - o if a.stop is not None else None) for a, o in zip(s[1:], self.min_s))) + yield ( + slice(None), + *tuple(slice(a.start - o, a.stop - o if a.stop is not None else None) for a, o in zip(s[1:], self.min_s)), + ) def empty_cache(device: torch.device): diff --git a/TPTBox/spine/snapshot2D/snapshot_modular.py b/TPTBox/spine/snapshot2D/snapshot_modular.py index bf86080..99549c8 100755 --- a/TPTBox/spine/snapshot2D/snapshot_modular.py +++ b/TPTBox/spine/snapshot2D/snapshot_modular.py @@ -810,7 +810,10 @@ def create_snapshot( # noqa: C901 seg = to_nii_optional(frame.segmentation, seg=True) # can be None ctd = copy.deepcopy(to_cdt(frame.centroids)) if (crop or frame.crop_msk) and seg is not None: # crop to segmentation - ex_slice = seg.compute_crop() + try: + ex_slice = seg.compute_crop() + except ValueError: + ex_slice = None img = img.copy().apply_crop_(ex_slice) seg = seg.copy().apply_crop_(ex_slice) ctd = ctd.apply_crop(ex_slice).filter_points_inside_shape() if ctd is not None else None diff --git a/TPTBox/spine/snapshot2D/snapshot_templates.py b/TPTBox/spine/snapshot2D/snapshot_templates.py index 564d05c..5fdd0ed 100755 --- a/TPTBox/spine/snapshot2D/snapshot_templates.py +++ b/TPTBox/spine/snapshot2D/snapshot_templates.py @@ -338,6 +338,7 @@ def ct_mri_snapshot( vert_msk_ct: Image_Reference | None = None, subreg_ctd_ct: POI_Reference | None = None, out_path=None, + return_frames=False, ): frames = [ Snapshot_Frame( @@ -361,6 +362,8 @@ def ct_mri_snapshot( crop_msk=False, ), ] + if return_frames: + return frames if out_path is None: assert isinstance(mrt_ref, BIDS_FILE) out_path = mrt_ref.get_changed_path(file_type="png", bids_format="snp", info={"desc": "vert-ct-mri"}) diff --git a/examples/registration/atlas_poi_transfer_leg/atlas_poi_transfer_leg_ct.py b/examples/registration/atlas_poi_transfer_leg/atlas_poi_transfer_leg_ct.py index 0b9d7f3..10be271 100644 --- a/examples/registration/atlas_poi_transfer_leg/atlas_poi_transfer_leg_ct.py +++ b/examples/registration/atlas_poi_transfer_leg/atlas_poi_transfer_leg_ct.py @@ -12,9 +12,12 @@ from time import time import numpy as np +from deepali.core import Axes, PathStr +from deepali.core import Grid as Deepali_Grid from TPTBox import POI, POI_Global, calc_centroids, to_nii from TPTBox.core.internal.deep_learning_utils import DEVICES +from TPTBox.core.nii_poi_abstract import Has_Grid from TPTBox.core.nii_wrapper import NII from TPTBox.core.poi import POI from TPTBox.core.poi_fun.poi_abstract import POI_Descriptor @@ -199,6 +202,46 @@ def parse_coordinates_to_poi(file_path: str | Path, left: bool): return POI_Global(global_points, itk_coords=True, level_one_info=Full_Body_Instance, level_two_info=Lower_Body) +def parse_poi_to_coordinates(poi: POI_Global, output_path: str | Path, left_side: bool | None = None): + poi = poi.to_global() + left_side = ("LEFT" if "LEFT" in str(output_path).upper() else "RIGHT") if left_side is None else "LEFT" if left_side else "RIGHT" + 
output_path = Path(output_path) + + lines = [] + lines.append("Points Export\n") + lines.append(f"Side: {left_side.upper()}\n\n") + ENUM_TO_ABBREVIATION = {k[1].value: v for v, k in ABBREVIATION_TO_ENUM.items()} + bone_structure = {"Femur proximal": [], "Femur distal": [], "Tibia proximal": [], "Tibia distal": [], "Patella": []} + + for _, k, coords in poi.items(): + abbrev = ENUM_TO_ABBREVIATION[k] + x, y, z = coords + coord_str = f"{abbrev}: ({x}, {y}, {z})" + # Assign to bone section heuristically (adjust as needed for your enums) + if abbrev in ["TGT", "FHC", "FNC", "FAAP"]: + bone_structure["Femur proximal"].append(coord_str) + elif abbrev in ["FLCD", "FMCD", "FLCP", "FMCP", "FNP", "FADP", "TGPP", "TGCP", "FMCPC", "FLCPC", "TRMP", "TRLP"]: + bone_structure["Femur distal"].append(coord_str) + elif abbrev in ["TLCL", "TMCM", "TKC", "TLCA", "TLCP", "TMCA", "TMCP", "TTP", "TAAP", "TMIT", "TLIT"]: + bone_structure["Tibia proximal"].append(coord_str) + elif abbrev in ["FLM", "TMM", "TAC", "TADP"]: + bone_structure["Tibia distal"].append(coord_str) + elif abbrev in ["PPP", "PDP", "PMP", "PLP", "PRPP", "PRDP", "PRHP"]: + bone_structure["Patella"].append(coord_str) + else: + raise ValueError(f"Unknown abbreviation '{abbrev}'") + + # Write structured output + for section, coords in bone_structure.items(): + if coords: + lines.append(f"{section}:\n") + lines.extend([line + "\n" for line in coords]) + lines.append("\n") + output_path.parent.mkdir(exist_ok=True, parents=True) + with open(output_path, "w") as f: + f.writelines(lines) + + default_setting = { "loss": { "config": {"be": {"stride": 1, "name": "BSplineBending"}, "seg": {"name": "MSE"}}, @@ -221,8 +264,8 @@ def parse_coordinates_to_poi(file_path: str | Path, left: bool): # reduce lr if optimization platos # increase min_delta when it stops to early -PATELLA = 14 -LEGS = [13, PATELLA, 15, 16] +PATELLA = [14, 114] +LEGS = [13, 113, *PATELLA, 115, 15, 116, 16] def split_left_right_leg(seg520: NII, c=2, min_volume=50) -> NII: @@ -452,13 +495,32 @@ def __init__( poi_cms: POI | None, same_side: bool, verbose=99, - setting=default_setting, - setting_patella=default_setting, + # setting=default_setting, + # setting_patella=default_setting, gpu=0, ddevice: DEVICES = "cuda", gaussian_sigma=0, + loss_terms=None, # type: ignore + weights=None, + weights_patella=None, + lr=0.001, + lr_patella=0.001, + max_steps=1500, + min_delta=0.00001, + pyramid_levels=4, + coarsest_level=3, + finest_level=0, + satellite_structure=PATELLA, + **args, ): + self.satellite_structure = satellite_structure # Assumes that you have removed the other leg. + if weights is None: + weights = {"be": 0.0001, "seg": 1} + if weights_patella is None: + weights_patella = {"be": 0.0001, "seg": 1} + if loss_terms is None: + loss_terms = {"be": ("BSplineBending", {"stride": 1}), "seg": "MSE"} assert target.seg assert atlas.seg target = target.copy() @@ -478,6 +540,7 @@ def __init__( raise ValueError(axis) for i in [200, 100000]: t_crop = (target).compute_crop(0, i) # if the angel is to different we need a larger crop... 
+ self.t_crop = t_crop target_ = target.apply_crop(t_crop) # Point Registration poi_target = calc_centroids(target_, second_stage=40) @@ -504,38 +567,45 @@ def __init__( atlas_reg.set_dtype_(np.float32) target = target.smooth_gaussian(gaussian_sigma) atlas_reg = atlas_reg.smooth_gaussian(gaussian_sigma) + self.reg_deform = Deformable_Registration( - target, - atlas_reg, - loss_terms={"be": ("BSplineBending", {"stride": 1}), "seg": "MSE"}, # type: ignore - weights={"be": setting["loss"]["weights"]["be"], "seg": setting["loss"]["weights"]["seg"]}, - lr=setting["optim"]["args"]["lr"], - max_steps=setting["optim"]["loop"]["max_steps"], - min_delta=setting["optim"]["loop"]["min_delta"], - pyramid_levels=4, - coarsest_level=3, - finest_level=0, + target.remove_labels(satellite_structure), + atlas_reg.remove_labels(satellite_structure), + loss_terms=loss_terms, # type: ignore + weights=weights, + lr=lr, + max_steps=max_steps, + min_delta=min_delta, + pyramid_levels=pyramid_levels, + coarsest_level=coarsest_level, + finest_level=finest_level, verbose=verbose, + **args, ) + if len(self.satellite_structure) == 0: + self.crop_patella = None + self.reg_deform_p = None + return # self.reg_deform = Deformable_Registration_old(target, atlas_reg, config=setting, verbose=verbose, gpu=gpu, ddevice=ddevice) atlas_reg = self.reg_deform.transform_nii(atlas_reg) # Patella - patella_atlas = atlas_reg.extract_label(PATELLA) - patella_target = target.extract_label(PATELLA) - self.crop_patella = (patella_target + patella_atlas).compute_crop(0, 2) + patella_atlas = atlas_reg.extract_label(satellite_structure) + patella_target = target.extract_label(satellite_structure) + + self.crop_patella = (patella_target + patella_atlas).compute_crop(0, 20) patella_atlas.apply_crop_(self.crop_patella) patella_target.apply_crop_(self.crop_patella) self.reg_deform_p = Deformable_Registration( patella_target, atlas_reg, - loss_terms={"be": ("BSplineBending", {"stride": 1}), "seg": "MSE"}, # type: ignore - weights={"be": setting_patella["loss"]["weights"]["be"], "seg": setting_patella["loss"]["weights"]["seg"]}, - lr=setting_patella["optim"]["args"]["lr"], - max_steps=setting_patella["optim"]["loop"]["max_steps"], - min_delta=setting_patella["optim"]["loop"]["min_delta"], - pyramid_levels=4, - coarsest_level=3, - finest_level=0, + loss_terms=loss_terms, + weights=weights_patella, + lr=lr_patella, + max_steps=max_steps, + min_delta=min_delta, + pyramid_levels=pyramid_levels, + coarsest_level=coarsest_level, + finest_level=finest_level, verbose=verbose, gpu=gpu, ddevice=ddevice, @@ -546,8 +616,16 @@ def get_dump(self): 1, # version (self.reg_point.get_dump()), (self.reg_deform.get_dump()), - (self.reg_deform_p.get_dump()), - (self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop, self.crop_patella), + (self.reg_deform_p.get_dump() if self.reg_deform_p is not None else None), + ( + self.same_side, + self.atlas_org, + self.target_grid_org, + self.target_grid, + self.crop, + self.crop_patella, + self.satellite_structure, + ), ) def save(self, path: str | Path): @@ -561,12 +639,13 @@ def load(cls, path): @classmethod def load_(cls, w): (version, t0, t1, t2, x) = w assert version == 1, f"Version mismatch {version=}" self = cls.__new__(cls) self.reg_point = Point_Registration.load_(t0) self.reg_deform = Deformable_Registration.load_(t1) - self.reg_deform_p = Deformable_Registration.load_(t2) - self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop, self.crop_patella = x + self.reg_deform_p = Deformable_Registration.load_(t2) if t2 is not None else None + self.same_side, self.atlas_org, self.target_grid_org, self.target_grid, self.crop, self.crop_patella, self.satellite_structure = x return self @@ -574,36 +653,46 @@ def forward_nii(self, nii_atlas: NII): nii_atlas = self.reg_point.transform_nii(nii_atlas) nii_atlas = nii_atlas.apply_crop(self.crop) nii_reg = self.reg_deform.transform_nii(nii_atlas) - patella_atlas = nii_reg.extract_label(PATELLA) - nii_reg[patella_atlas == 1] = 0 - patella_atlas.apply_crop_(self.crop_patella) - patella_atlas_reg = self.reg_deform_p.transform_nii(patella_atlas) - patella_atlas_reg.resample_from_to_(nii_reg) - nii_reg[patella_atlas_reg != 0] = PATELLA - nii_reg = nii_reg.resample_from_to(self.target_grid_org) + if self.crop_patella is not None and self.reg_deform_p is not None: + # Patella + patella_atlas = nii_reg.extract_label(*self.satellite_structure) + nii_reg.remove_labels_(self.satellite_structure) + patella_atlas_reg = self.reg_deform_p.transform_nii(patella_atlas) + patella_atlas_reg.resample_from_to_(nii_reg, mode="constant") + nii_reg[patella_atlas_reg != 0] = patella_atlas_reg[patella_atlas_reg != 0] + out = nii_reg.resample_from_to(self.target_grid_org) + if self.same_side: - return nii_reg - return nii_reg.set_array(nii_reg.get_array()[::-1]) + return out + axis = out.get_axis("R") + if axis == 0: + target = out.set_array(out.get_array()[::-1]).copy() + elif axis == 1: + target = out.set_array(out.get_array()[:, ::-1]).copy() + elif axis == 2: + target = out.set_array(out.get_array()[:, :, ::-1]).copy() + else: + raise ValueError(axis) + return target def forward_txt(self, file_path: str | Path, left: bool): poi_glob = parse_coordinates_to_poi(file_path, left) - return self.forward_poi(poi_glob, left) + return self.forward_poi(poi_glob) - def forward_poi(self, poi_atlas: POI_Global | POI, left): + def forward_poi(self, poi_atlas: POI_Global | POI): poi_atlas = poi_atlas.resample_from_to(self.atlas_org) + # Point Reg poi_atlas = self.reg_point.transform_poi(poi_atlas) # Deformable poi_atlas = poi_atlas.apply_crop(self.crop) poi_reg = self.reg_deform.transform_poi(poi_atlas) - # Patella - poi_patella = poi_reg.apply_crop(self.crop_patella).extract_region( - Full_Body_Instance.patella_left.value, Full_Body_Instance.patella_right.value - ) - patella_poi_reg = self.reg_deform_p.transform_poi(poi_patella) - for k1, k2, v in patella_poi_reg.resample_from_to(poi_reg).items(): - poi_reg[k1, k2] = v - # poi_reg.save(root / "test" / "subreg_reg.json") + if self.crop_patella is not None and self.reg_deform_p is not None: + # Patella + poi_patella = poi_reg.apply_crop(self.crop_patella).extract_region(*self.satellite_structure) + patella_poi_reg = self.reg_deform_p.transform_poi(poi_patella) + for k1, k2, v in patella_poi_reg.resample_from_to(poi_reg).items(): + poi_reg[k1, k2] = v poi_reg = poi_reg.resample_from_to(self.target_grid_org) poi_reg.level_one_info = Full_Body_Instance poi_reg.level_two_info = Lower_Body @@ -611,8 +700,8 @@ def forward_poi(self, poi_atlas: POI_Global | POI, left): return poi_reg for k1, k2, v in poi_reg.copy().items(): k = k1 % 100 - if left: - k += 100 + # if left: + # k += 100 poi_reg[k, k2] = v poi_reg_flip = poi_reg.make_empty_POI() for k1, k2, (x, y, z) in poi_reg.copy().items(): @@ -629,6 +718,48 @@ poi_reg_flip.level_two_info = Lower_Body return poi_reg_flip + def invert_points( + self, points, axes: Axes, grid: Deepali_Grid | Has_Grid, to_axes: Axes = Axes.CUBE, to_grid: Deepali_Grid | Has_Grid | None = None + ): + raise NotImplementedError("This method is not tested and might not work as 
expected.") + import torch + + reg_deform = self.reg_deform.inverse() + # invert reg_deform + points = reg_deform.transform_points(points, axes=axes, to_axes=Axes.GRID, grid=grid, to_grid=self.target_grid) + # invert self.crop + shift = torch.tensor([x.start for x in self.crop]).unsqueeze(0).to(points.device) + points = points + shift + # invert reg_point + out = [] + for x, y, z in points: + o = self.reg_point.transform_cord_inverse((x, y, z)) # type: ignore + out.append(o) + out = torch.tensor(out) + # invert self.t_crop + shift = torch.tensor([x.start for x in self.t_crop]).unsqueeze(0) + points = points + shift + # flip if needed + if not self.same_side: + grid = self.target_grid_org + axis = grid.get_axis("R") + if axis == 0: + out[:, 0] = grid.shape[0] - 1 - out[:, 0] + elif axis == 1: + out[:, 1] = grid.shape[1] - 1 - out[:, 1] + elif axis == 2: + out[:, 2] = grid.shape[2] - 1 - out[:, 2] + else: + raise ValueError(axis) + # transform to axes and grid + if to_grid is not None or to_axes != Axes.CUBE: + if to_grid is None: + to_grid = self.target_grid_org + if isinstance(to_grid, Has_Grid): + to_grid = to_grid.to_deepali_grid() + out = to_grid.transform_points(out, axes.CUBE, to_axes=to_axes, to_grid=to_grid) + return out + class Register_Point_Atlas_bone_by_bone: def __init__( @@ -855,7 +986,7 @@ def forward_poi(self, poi_atlas: POI_Global | POI, left): p = poi_atlas.to_global(itk_coords=True) coords_dict = parse_coordinates("010__left.txt") for e, (k2, _) in enumerate(coords_dict.items(), 1): - k = PATELLA if k2[0] == "P" else 60 + k = PATELLA[0] if k2[0] == "P" else 60 print(k2, p[k, e]) print(time() - st) diff --git a/examples/registration/atlas_poi_transfer_leg/example.py b/examples/registration/atlas_poi_transfer_leg/example.py index 812a394..2e80189 100644 --- a/examples/registration/atlas_poi_transfer_leg/example.py +++ b/examples/registration/atlas_poi_transfer_leg/example.py @@ -14,7 +14,7 @@ text_file_is_left_leg = True file_text = "/DATA/NAS/tools/TPTBox/examples/atlas_poi_transfer_leg/010__left.txt" segmentation_path = "/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/source/Dataset001_all/0001/bone.nii.gz" -out_folder = Path("/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/atlas") +out_folder = Path("/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/atlas2") atlas_id = 1 ########################################## # Load segmentation diff --git a/pyproject.toml b/pyproject.toml index b2f4dfd..d9d0029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,7 @@ ignore = [ "N806", "B905", # strict= in zip "UP007", # Union and "|" python 3.9 + "PLC0415", # import-outside-top-level ] # Allow fix for all enabled rules (when `--fix`) is provided. 
diff --git a/unit_tests/test_centroids.py b/unit_tests/test_poi_centroids.py
old mode 100755
new mode 100644
similarity index 99%
rename from unit_tests/test_centroids.py
rename to unit_tests/test_poi_centroids.py
index b523f0f..24909d0
--- a/unit_tests/test_centroids.py
+++ b/unit_tests/test_poi_centroids.py
@@ -17,7 +17,7 @@
 from TPTBox.core.poi_fun.save_load import _poi_to_dict_list, load_poi
 from TPTBox.core.vert_constants import Location
 from TPTBox.tests.test_utils import get_nii, get_random_ax_code, overlap, repeats
-from unit_tests.test_centroids_save import get_centroids
+from unit_tests.test_poi_save import get_centroids


 class Test_Centroids(unittest.TestCase):
diff --git a/unit_tests/test_centroids_save.py b/unit_tests/test_poi_save.py
old mode 100755
new mode 100644
similarity index 81%
rename from unit_tests/test_centroids_save.py
rename to unit_tests/test_poi_save.py
index 2cc4b10..2131f91
--- a/unit_tests/test_centroids_save.py
+++ b/unit_tests/test_poi_save.py
@@ -115,14 +115,16 @@ def test_save_Glob(self):
     def test_save_Glob_2(self):
         for _ in range(repeats):
             p = Path(s, "test_save_glob_2.json")
+            p2 = Path(s, "test_save_glob_3.json")
             cdt = get_centroids(x=get_random_shape(), num_point=20)
             glob_poi = cdt.to_global()
             cdt.save(p, verbose=False)
-            glob_poi.save(p, verbose=False)
+            glob_poi.save(p2, verbose=False)
             cdt_a = POI_Global.load(p)
-            cdt_b = POI_Global.load(p)
+            cdt_b = POI_Global.load(p2)
             self.assertEqual(cdt_a, cdt_b)
             Path(p).unlink()
+            Path(p2).unlink()

     def test_save_all(self):
         for _ in range(repeats):
@@ -135,6 +137,29 @@ def test_save_all(self):
             self.assertEqual(cdt, cdt2)
             Path(p).unlink()

+    def test_save_Glob_mkr(self):
+        for _ in range(repeats):
+            p = Path(s, "test_save_glob.mrk.json")
+            cdt = get_centroids(x=get_random_shape(), num_point=20).to_global()
+            cdt.save_mrk(p)
+            cdt2 = POI_Global.load(p)
+            self.assertEqual(cdt, cdt2)
+            Path(p).unlink()
+
+    def test_save_Glob_2_mkr(self):
+        for _ in range(repeats):
+            p = Path(s, "test_save_glob_2.json")
+            p2 = Path(s, "test_save_glob_3.mrk.json")
+            cdt = get_centroids(x=get_random_shape(), num_point=20)
+            glob_poi = cdt.to_global()
+            cdt.save(p, verbose=False)
+            glob_poi.save_mrk(p2)
+            cdt_a = POI_Global.load(p)
+            cdt_b = POI_Global.load(p2)
+            self.assertEqual(cdt_a, cdt_b)
+            Path(p).unlink()
+            Path(p2).unlink()
+

 if __name__ == "__main__":
     unittest.main()
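The two new `*_mkr` tests pin down the round-trip invariant behind `save_mrk`: a local POI saved in the native JSON format and its global counterpart saved as a Slicer markup must load back to equal `POI_Global` objects. A condensed sketch of that invariant, assuming the `get_centroids` helper from the renamed `unit_tests/test_poi_save.py` and an arbitrary example shape:

```python
# Condensed from test_save_Glob_2_mkr above; get_centroids comes from the renamed
# test module, and the shape tuple passed as x is an arbitrary example (assumption).
from pathlib import Path

from TPTBox import POI_Global
from unit_tests.test_poi_save import get_centroids

local_poi = get_centroids(x=(50, 60, 70), num_point=5)
glob_poi = local_poi.to_global()

local_poi.save("points.json", verbose=False)  # native POI format
glob_poi.save_mrk("points.mrk.json")          # Slicer-compatible markup

# Both files decode to the same global coordinates
assert POI_Global.load("points.json") == POI_Global.load("points.mrk.json")

Path("points.json").unlink()
Path("points.mrk.json").unlink()
```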