1515import numpy as np
1616
1717from monai .deploy .utils .importutil import optional_import
18+ from monai .utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
1819
1920MONAI_UTILS = "monai.utils"
2021torch , _ = optional_import ("torch" , "1.5" )
2829 ImageReader = object # for 'class InMemImageReader(ImageReader):' to work
2930decollate_batch , _ = optional_import ("monai.data" , name = "decollate_batch" )
3031sliding_window_inference , _ = optional_import ("monai.inferers" , name = "sliding_window_inference" )
32+ simple_inference , _ = optional_import ("monai.inferers" , name = "SimpleInferer" )
3133ensure_tuple , _ = optional_import (MONAI_UTILS , name = "ensure_tuple" )
3234MetaKeys , _ = optional_import (MONAI_UTILS , name = "MetaKeys" )
3335SpaceKeys , _ = optional_import (MONAI_UTILS , name = "SpaceKeys" )
4042
4143from .inference_operator import InferenceOperator
4244
43- __all__ = ["MonaiSegInferenceOperator" , "InMemImageReader" ]
45+ __all__ = ["MonaiSegInferenceOperator" , "InfererType" , "InMemImageReader" ]
46+
47+
48+ class InfererType (StrEnum ):
49+ """Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
50+
51+ SIMPLE = "simple"
52+ SLIDING_WINDOW = "sliding_window"
4453
4554
4655@md .input ("image" , Image , IOType .IN_MEMORY )
@@ -61,22 +70,30 @@ class MonaiSegInferenceOperator(InferenceOperator):
6170
6271 def __init__ (
6372 self ,
64- roi_size : Union [Sequence [int ], int ],
73+ roi_size : Optional [ Union [Sequence [int ], int ] ],
6574 pre_transforms : Compose ,
6675 post_transforms : Compose ,
6776 model_name : Optional [str ] = "" ,
68- overlap : float = 0.5 ,
77+ overlap : float = 0.25 ,
78+ sw_batch_size : int = 4 ,
79+ inferer : Union [InfererType , str ] = InfererType .SLIDING_WINDOW ,
6980 * args ,
7081 ** kwargs ,
7182 ):
7283 """Creates a instance of this class.
7384
7485 Args:
75- roi_size (Union[Sequence[int], int]): The tensor size used in inference.
86+ roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
87+ An optional input only to be passed for "SLIDING_WINDOW".
88+ If using a "SIMPLE" Inferer, this input is ignored.
7689 pre_transforms (Compose): MONAI Compose object used for pre-transforms.
7790 post_transforms (Compose): MONAI Compose object used for post-transforms.
7891 model_name (str, optional): Name of the model. Default to "" for single model app.
79- overlap (float): The overlap used in sliding window inference.
92+ overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
93+ Applicable for "SLIDING_WINDOW" only.
94+ sw_batch_size(int): The batch size to run window slices. Defaults to 4.
95+ Applicable for "SLIDING_WINDOW" only.
96+ inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
8097 """
8198
8299 super ().__init__ ()
@@ -90,7 +107,9 @@ def __init__(
90107 self ._pre_transform = pre_transforms
91108 self ._post_transforms = post_transforms
92109 self ._model_name = model_name .strip () if isinstance (model_name , str ) else ""
93- self .overlap = overlap
110+ self ._overlap = overlap
111+ self ._sw_batch_size = sw_batch_size
112+ self ._inferer = inferer
94113
95114 @property
96115 def roi_size (self ):
@@ -134,6 +153,28 @@ def overlap(self, val: float):
134153 raise ValueError ("Overlap must be between 0 and 1." )
135154 self ._overlap = val
136155
156+ @property
157+ def sw_batch_size (self ):
158+ """The batch size to run window slices"""
159+ return self ._sw_batch_size
160+
161+ @sw_batch_size .setter
162+ def sw_batch_size (self , val : int ):
163+ if not isinstance (val , int ) or val < 0 :
164+ raise ValueError ("sw_batch_size must be a positive integer." )
165+ self ._sw_batch_size = val
166+
167+ @property
168+ def inferer (self ) -> Union [InfererType , str ]:
169+ """The type of inferer to use"""
170+ return self ._inferer
171+
172+ @inferer .setter
173+ def inferer (self , val : InfererType ):
174+ if not isinstance (val , InfererType ):
175+ raise ValueError (f"Value must be of the correct type { InfererType } ." )
176+ self ._inferer = val
177+
137178 def _convert_dicom_metadata_datatype (self , metadata : Dict ):
138179 """Converts metadata in pydicom types to the corresponding native types.
139180
@@ -218,14 +259,22 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
218259 with torch .no_grad ():
219260 for d in dataloader :
220261 images = d [self ._input_dataset_key ].to (device )
221- sw_batch_size = 4
222- d [self ._pred_dataset_key ] = sliding_window_inference (
223- inputs = images ,
224- roi_size = self ._roi_size ,
225- sw_batch_size = sw_batch_size ,
226- overlap = self .overlap ,
227- predictor = model ,
228- )
262+ if self ._inferer == InfererType .SLIDING_WINDOW :
263+ # Uses the util function to drive the sliding_window inferer
264+ d [self ._pred_dataset_key ] = sliding_window_inference (
265+ inputs = images ,
266+ roi_size = self ._roi_size ,
267+ sw_batch_size = self ._sw_batch_size ,
268+ overlap = self ._overlap ,
269+ predictor = model ,
270+ )
271+ elif self ._inferer == InfererType .SIMPLE :
272+ # Instantiates the SimpleInferer and directly uses its __call__ function
273+ d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = model )
274+ else :
275+ raise ValueError (
276+ f"Unknown inferer: { self ._inferer } . Available options are sliding_window or simple."
277+ )
229278 d = [post_transforms (i ) for i in decollate_batch (d )]
230279 out_ndarray = d [0 ][self ._pred_dataset_key ].cpu ().numpy ()
231280 # Need to squeeze out the channel dim fist
@@ -241,8 +290,10 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
241290 out_ndarray = out_ndarray .T .astype (np .uint8 )
242291 print (f"Output Seg image numpy array shaped: { out_ndarray .shape } " )
243292 print (f"Output Seg image pixel max value: { np .amax (out_ndarray )} " )
293+ print (f"Output Seg image pixel min value: { np .amin (out_ndarray )} " )
244294 out_image = Image (out_ndarray , input_img_metadata )
245295 op_output .set (out_image , "seg_image" )
296+
246297 finally :
247298 # Reset state on completing this method execution.
248299 with self ._lock :
0 commit comments