Skip to content

Commit 5e62b9c

Browse files
authored
Merge pull request #96 from Hendrik-code/snapshot_readability
fix issues with the nnunet slitting script
2 parents b49a9d5 + 0c5e9c1 commit 5e62b9c

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

TPTBox/core/nii_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ def astype(self,dtype,order:Literal["C","F","A","K"] ='K', casting:Literal["no",
667667
return c
668668
else:
669669
return self.get_array().astype(dtype,order=order,casting=casting, subok=subok,copy=copy)
670-
def reorient(self:Self, axcodes_to: AX_CODES|None = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self:
670+
def reorient(self:Self, axcodes_to: AX_CODES|str|None = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self:
671671
"""
672672
Reorients the input Nifti image to the desired orientation, specified by the axis codes.
673673
@@ -1863,7 +1863,7 @@ def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None
18631863
self.set_dtype_("smallest_uint")
18641864
arr = self.get_array() if not self.seg else self.get_seg_array()
18651865

1866-
1866+
self.header.set_data_dtype(arr.dtype)
18671867
out = Nifti1Image(arr, self.affine,self.header)#,dtype=arr.dtype)
18681868
if dtype is not None:
18691869
out.set_data_dtype(dtype)
@@ -1929,7 +1929,7 @@ def save_nrrd(self:Self, file: str | Path|bids_files.BIDS_FILE,make_parents=True
19291929
# Save NRRD file
19301930

19311931
log.print(f"Saveing {file}",verbose=verbose,ltype=Log_Type.SAVE,end='\r')
1932-
nrrd.write(file, data=data, header=header,**args)
1932+
nrrd.write(str(file), data=data, header=header,**args) # nrrd only acepts strings...
19331933
log.print(f"Save {file} as {header['type']}",verbose=verbose,ltype=Log_Type.SAVE)
19341934

19351935
def __str__(self) -> str:

TPTBox/segmentation/VibeSeg/inference_nnunet.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ def run_inference_on_file(
7272
if out_file is not None and Path(out_file).exists() and not override:
7373
return out_file, None
7474

75-
from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference
75+
from TPTBox.segmentation.nnUnet_utils.inference_api import (
76+
load_inf_model,
77+
run_inference,
78+
)
7679

7780
if isinstance(idx, int):
7881
download_weights(idx, model_path)
@@ -102,7 +105,7 @@ def run_inference_on_file(
102105
if "model_expected_orientation" in ds_info2:
103106
ds_info["orientation"] = ds_info2["model_expected_orientation"]
104107
if "resolution_range" in ds_info2:
105-
ds_info["spacing"] = ds_info2["resolution_range"]
108+
ds_info["resolution_range"] = ds_info2["resolution_range"]
106109

107110
nnunet = load_inf_model(
108111
nnunet_path,
@@ -127,7 +130,10 @@ def run_inference_on_file(
127130

128131
try:
129132
zoom_old = ds_info.get("spacing")
133+
if idx not in [527] and zoom_old is not None:
134+
zoom_old = zoom_old[::-1]
130135

136+
zoom_old = ds_info.get("resolution_range", zoom_old)
131137
if zoom_old is None:
132138
zoom = plans_info["configurations"]["3d_fullres"]["spacing"]
133139
if all(zoom[0] == z for z in zoom):

TPTBox/segmentation/nnUnet_utils/predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,8 @@ def check_mem(shape):
523523
j = np.argmax(s)
524524
if s[j] == 1:
525525
device = "cpu"
526-
print("Fall Back CPU. Not enough space", shape, patch_size, splits, s)
526+
print("Fall Back CPU. Not enough space; s[j] == 1", shape, patch_size, splits, s)
527527
break
528-
splits[j] += 1
529528
shape_split = [ceil(s / sp) for s, sp in zip(shape, splits)]
530529
# print(shape, patch_size, splits, s, np.prod(shape) / 1000000)
531530
if check_mem(shape_split):
@@ -543,6 +542,7 @@ def check_mem(shape):
543542
print(e)
544543
break
545544

545+
splits[j] += 1
546546
predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar)
547547
pbar.desc = "finish"
548548
pbar.update(0)

0 commit comments

Comments
 (0)