Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions TPTBox/core/nii_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def astype(self,dtype,order:Literal["C","F","A","K"] ='K', casting:Literal["no",
return c
else:
return self.get_array().astype(dtype,order=order,casting=casting, subok=subok,copy=copy)
def reorient(self:Self, axcodes_to: AX_CODES|None = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self:
def reorient(self:Self, axcodes_to: AX_CODES|str|None = ("P", "I", "R"), verbose:logging=False, inplace=False)-> Self:
"""
Reorients the input Nifti image to the desired orientation, specified by the axis codes.

Expand Down Expand Up @@ -1863,7 +1863,7 @@ def save(self,file:str|Path,make_parents=True,verbose:logging=True, dtype = None
self.set_dtype_("smallest_uint")
arr = self.get_array() if not self.seg else self.get_seg_array()


self.header.set_data_dtype(arr.dtype)
out = Nifti1Image(arr, self.affine,self.header)#,dtype=arr.dtype)
if dtype is not None:
out.set_data_dtype(dtype)
Expand Down Expand Up @@ -1929,7 +1929,7 @@ def save_nrrd(self:Self, file: str | Path|bids_files.BIDS_FILE,make_parents=True
# Save NRRD file

log.print(f"Saveing {file}",verbose=verbose,ltype=Log_Type.SAVE,end='\r')
nrrd.write(file, data=data, header=header,**args)
nrrd.write(str(file), data=data, header=header,**args) # nrrd only acepts strings...
log.print(f"Save {file} as {header['type']}",verbose=verbose,ltype=Log_Type.SAVE)

def __str__(self) -> str:
Expand Down
10 changes: 8 additions & 2 deletions TPTBox/segmentation/VibeSeg/inference_nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def run_inference_on_file(
if out_file is not None and Path(out_file).exists() and not override:
return out_file, None

from TPTBox.segmentation.nnUnet_utils.inference_api import load_inf_model, run_inference
from TPTBox.segmentation.nnUnet_utils.inference_api import (
load_inf_model,
run_inference,
)

if isinstance(idx, int):
download_weights(idx, model_path)
Expand Down Expand Up @@ -102,7 +105,7 @@ def run_inference_on_file(
if "model_expected_orientation" in ds_info2:
ds_info["orientation"] = ds_info2["model_expected_orientation"]
if "resolution_range" in ds_info2:
ds_info["spacing"] = ds_info2["resolution_range"]
ds_info["resolution_range"] = ds_info2["resolution_range"]

nnunet = load_inf_model(
nnunet_path,
Expand All @@ -127,7 +130,10 @@ def run_inference_on_file(

try:
zoom_old = ds_info.get("spacing")
if idx not in [527] and zoom_old is not None:
zoom_old = zoom_old[::-1]

zoom_old = ds_info.get("resolution_range", zoom_old)
if zoom_old is None:
zoom = plans_info["configurations"]["3d_fullres"]["spacing"]
if all(zoom[0] == z for z in zoom):
Expand Down
4 changes: 2 additions & 2 deletions TPTBox/segmentation/nnUnet_utils/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,8 @@ def check_mem(shape):
j = np.argmax(s)
if s[j] == 1:
device = "cpu"
print("Fall Back CPU. Not enough space", shape, patch_size, splits, s)
print("Fall Back CPU. Not enough space; s[j] == 1", shape, patch_size, splits, s)
break
splits[j] += 1
shape_split = [ceil(s / sp) for s, sp in zip(shape, splits)]
# print(shape, patch_size, splits, s, np.prod(shape) / 1000000)
if check_mem(shape_split):
Expand All @@ -543,6 +542,7 @@ def check_mem(shape):
print(e)
break

splits[j] += 1
predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar)
pbar.desc = "finish"
pbar.update(0)
Expand Down