From 6e3201f3ced62ab2efc40d93fa2fb485514b87ea Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 15 Jul 2024 19:33:42 -0700 Subject: [PATCH 01/31] Initial profiling/export Signed-off-by: Boris Fomitchev --- scripts/export.bash | 3 + scripts/export.py | 321 ++++++++++++++++++++++++++++++++ scripts/utils/trans_utils.py | 3 +- vista3d/modeling/segresnetds.py | 7 +- vista3d/modeling/vista3d.py | 51 ++++- 5 files changed, 379 insertions(+), 6 deletions(-) create mode 100755 scripts/export.bash create mode 100644 scripts/export.py diff --git a/scripts/export.bash b/scripts/export.bash new file mode 100755 index 0000000..69c734d --- /dev/null +++ b/scripts/export.bash @@ -0,0 +1,3 @@ +# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz' + +python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true diff --git a/scripts/export.py b/scripts/export.py new file mode 100644 index 0000000..0423549 --- /dev/null +++ b/scripts/export.py @@ -0,0 +1,321 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +from functools import partial + +import monai +import numpy as np +import torch +import torch.distributed as dist +from monai import transforms +from monai.apps.auto3dseg.auto_runner import logger +from monai.auto3dseg.utils import datafold_read +from monai.bundle import ConfigParser +from monai.bundle.scripts import _pop_args, _update_args +from monai.data import decollate_batch, list_data_collate, partition_dataset +from monai.utils import optional_import + +from vista3d import vista_model_registry + +from .sliding_window import point_based_window_inferer, sliding_window_inference +from .train import CONFIG +from .utils.trans_utils import VistaPostTransform + +rearrange, _ = optional_import("einops", name="rearrange") +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) +IGNORE_PROMPT = set( + [ + 2, # kidney + 16, # prostate or uterus + 18, # rectum + 20, # lung + 21, # bone + 23, # lung tumor + 24, # pancreatic tumor + 25, # hepatic vessel + 26, # hepatic tumor + 27, # colon cancer primaries + 128, # bone lesion + 129, # kidney mass + 130, # liver tumor + 131, # vertebrae L6 + 132, + ] +) # airway +EVERYTHING_PROMPT = list(set([i + 1 for i in range(133)]) - IGNORE_PROMPT) + + +def infer_wrapper(inputs, model, **kwargs): + outputs = model(input_images=inputs, **kwargs) + return outputs.transpose(1, 0) + + +class InferClass: + def __init__(self, config_file="./configs/infer.yaml", **override): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + _args = _update_args(config_file=config_file, **override) + config_file_ = _pop_args(_args, "config_file")[0] + + parser = ConfigParser() + parser.read_config(config_file_) + parser.update(pairs=_args) + + # We do not use AMP for export + self.amp = False # parser.get_parsed_content("amp") + input_channels = 
parser.get_parsed_content("input_channels")
+        patch_size = parser.get_parsed_content("patch_size")
+        self.patch_size = patch_size
+
+        ckpt_name = parser.get_parsed_content("infer")["ckpt_name"]
+        output_path = parser.get_parsed_content("infer")["output_path"]
+        if not os.path.exists(output_path):
+            os.makedirs(output_path, exist_ok=True)
+
+        CONFIG["handlers"]["file"]["filename"] = parser.get_parsed_content("infer")[
+            "log_output_file"
+        ]
+        logging.config.dictConfig(CONFIG)
+        self.infer_transforms = parser.get_parsed_content("transforms_infer")
+
+        self.device = torch.device("cuda:0")
+        model_registry = parser.get_parsed_content("model")
+        model = vista_model_registry[model_registry](
+            in_channels=input_channels, image_size=patch_size
+        )
+        self.model = model.to(self.device)
+
+        pretrained_ckpt = torch.load(ckpt_name, map_location=self.device)
+        self.model.load_state_dict(pretrained_ckpt, strict=False)
+        logger.debug(f"[debug] checkpoint {ckpt_name:s} loaded")
+        post_transforms = [
+            VistaPostTransform(keys="pred"),
+            transforms.Invertd(
+                keys="pred",
+                transform=self.infer_transforms,
+                orig_keys="image",
+                meta_keys="pred_meta_dict",
+                orig_meta_keys="image_meta_dict",
+                meta_key_postfix="meta_dict",
+                nearest_interp=True,
+                to_tensor=True,
+            ),
+        ]
+
+        # For Vista3d, sigmoid is always used, but for visualization, argmax is needed
+        save_transforms = [
+            transforms.SaveImaged(
+                keys="pred",
+                meta_keys="pred_meta_dict",
+                output_dir=output_path,
+                output_postfix="seg",
+                resample=False,
+                data_root_dir=None,
+                print_log=False,
+            )
+        ]
+        self.post_transforms = transforms.Compose(post_transforms)
+        self.save_transforms = transforms.Compose(save_transforms)
+        self.prev_mask = None
+        self.batch_data = None
+        return
+
+    def clear_cache(self):
+        self.prev_mask = None
+        self.batch_data = None
+
+    def transform_points(self, point, affine):
+        """Transform points into the coordinates of the transformed image.
+        point: numpy array [bs, N, 3]
+        """
+        bs, N = point.shape[:2]
+        point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1)
+        point = rearrange(point, "b n d -> d (b n)")
+        point = affine @ point
+        point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
+        return point
+
+    @torch.no_grad()
+    def infer(
+        self,
+        image_file,
+        point=None,
+        point_label=None,
+        label_prompt=None,
+        prompt_class=None,
+        save_mask=False,
+        point_start=0,
+    ):
+        """Infer a single image_file. If save_mask is True, save the argmax prediction to disk. If False,
+        do not save and return the probability maps instead (usually used by the AutoRunner ensembler). point_start is
+        used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
+        time and avoid repeated inference. This is disabled by default.
+ """ + self.model.eval() + if not isinstance(image_file, dict): + image_file = {"image": image_file} + if self.batch_data is not None: + batch_data = self.batch_data + else: + batch_data = self.infer_transforms(image_file) + batch_data["label_prompt"] = label_prompt + batch_data = list_data_collate([batch_data]) + self.batch_data = batch_data + if point is not None: + point = self.transform_points( + point, + np.linalg.inv(batch_data["image"].affine[0]) + @ batch_data["image"].meta["original_affine"][0].numpy(), + ) + self.sliding_window_inferer = partial( + point_based_window_inferer, point_start=point_start + ) + else: + self.sliding_window_inferer = sliding_window_inference + device_list_input = [self.device, self.device, "cpu"] + device_list_output = [self.device, "cpu", "cpu"] + for _device_in, _device_out in zip(device_list_input, device_list_output): + try: + with torch.cuda.amp.autocast(enabled=self.amp): + batch_data["pred"] = self.sliding_window_inferer( + inputs=batch_data["image"].to(_device_in), + roi_size=self.patch_size, + sw_batch_size=1, + predictor=partial(infer_wrapper, model=self.model), + mode="gaussian", + overlap=0.625, + progress=True, + sw_device=self.device, + device=_device_out, + point_coords=( + torch.tensor(point).to(_device_in) + if point is not None + else None + ), + point_labels=( + torch.tensor(point_label).to(_device_in) + if point_label is not None + else None + ), + class_vector=( + torch.tensor(label_prompt).to(_device_in) + if label_prompt is not None + else None + ), + prompt_class=( + torch.tensor(prompt_class).to(_device_in) + if prompt_class is not None + else None + ), + prev_mask=( + torch.tensor(self.prev_mask).to(_device_in) + if self.prev_mask is not None + else None + ), + ) + + if not hasattr(batch_data["pred"], "meta"): + batch_data["pred"] = monai.data.MetaTensor( + batch_data["pred"], + affine=batch_data["image"].meta["affine"], + meta=batch_data["image"].meta, + ) + self.prev_mask = batch_data["pred"] + batch_data["image"] = batch_data["image"].to("cpu") + batch_data["pred"] = batch_data["pred"].to("cpu") + torch.cuda.empty_cache() + batch_data = [ + self.post_transforms(i) for i in decollate_batch(batch_data) + ] + if save_mask: + batch_data = [self.save_transforms(i) for i in batch_data] + + finished = True + except RuntimeError as e: + if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")): + raise e + finished = False + if finished: + break + if not finished: + raise RuntimeError("Infer not finished due to OOM.") + return batch_data[0]["pred"] + + @torch.no_grad() + def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): + self.model.eval() + device = f"cuda:{rank}" + if not isinstance(image_file, dict): + image_file = {"image": image_file} + batch_data = self.infer_transforms(image_file) + batch_data["label_prompt"] = label_prompt + batch_data = list_data_collate([batch_data]) + device_list_input = [device, device, "cpu"] + device_list_output = [device, "cpu", "cpu"] + for _device_in, _device_out in zip(device_list_input, device_list_output): + try: + with torch.cuda.amp.autocast(enabled=self.amp): + batch_data["pred"] = sliding_window_inference( + inputs=batch_data["image"].to(_device_in), + roi_size=self.patch_size, + sw_batch_size=1, + predictor=partial(infer_wrapper, model=self.model), + mode="gaussian", + overlap=0.625, + sw_device=device, + device=_device_out, + class_vector=torch.tensor(label_prompt).to(_device_in), + ) + if not hasattr(batch_data["pred"], "meta"): + batch_data["pred"] 
= monai.data.MetaTensor( + batch_data["pred"], + affine=batch_data["image"].meta["affine"], + meta=batch_data["image"].meta, + ) + torch.cuda.empty_cache() + batch_data = [ + self.post_transforms(i) for i in decollate_batch(batch_data) + ] + batch_data = [self.save_transforms(i) for i in batch_data] + finished = True + except RuntimeError as e: + if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")): + raise e + finished = False + if finished: + break + if not finished: + raise RuntimeError("Infer not finished due to OOM.") + + @torch.no_grad() + def batch_infer_everything(self, datalist=str, basedir=str): + train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=0) + train_files = [_["image"] for _ in train_files] + dist.init_process_group(backend="nccl", init_method="env://") + world_size = dist.get_world_size() + rank = dist.get_rank() + # no need to wrap model with DistributedDataParallel + self.model = self.model.to(f"cuda:{rank}") + infer_files = partition_dataset( + data=train_files, + shuffle=False, + num_partitions=world_size, + even_divisible=False, + )[rank] + self.infer(infer_files, label_prompt=EVERYTHING_PROMPT, rank=rank) + + +if __name__ == "__main__": + fire, _ = optional_import("fire") + fire.Fire(InferClass) diff --git a/scripts/utils/trans_utils.py b/scripts/utils/trans_utils.py index c6dc0c0..d4520a1 100644 --- a/scripts/utils/trans_utils.py +++ b/scripts/utils/trans_utils.py @@ -349,7 +349,8 @@ def __call__( pred += 0.5 # inplace mapping to avoid cloning pred for i in range(1, object_num + 1): frac = i + 0.5 - pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) + pred[pred == frac] = torch.tensor(data["label_prompt"][i - 1]).to(pred.dtype) + # pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) pred[pred == 0.5] = 0.0 data[keys] = pred return data diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index e8c96d5..0a39508 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -238,7 +238,6 @@ def _forward(self, x: torch.Tensor) -> list[torch.Tensor]: if self.head_module is not None: outputs = self.head_module(outputs) - return outputs def forward(self, x: torch.Tensor) -> list[torch.Tensor]: @@ -464,7 +463,7 @@ def is_valid_shape(self, x): def _forward( self, x: torch.Tensor, with_point, with_label - ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: if self.preprocess is not None: x = self.preprocess(x) @@ -521,8 +520,8 @@ def _forward( return outputs, outputs_auto def forward( - self, x: torch.Tensor, with_point=True, with_label=True, **kwargs - ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + self, x: torch.Tensor, with_point=True, with_label=True, # **kwargs + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: return self._forward(x, with_point, with_label) def set_auto_grad(self, auto_freeze=False, point_freeze=False): diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index 8352ba9..39505be 100644 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn from monai.utils import optional_import +import time from scripts.utils.trans_utils import convert_points_to_disc from scripts.utils.trans_utils import get_largest_connected_component_mask as lcc @@ -41,7 +42,8 @@ def __init__(self, image_encoder, class_head, point_head, feature_size): ) self.auto_freeze = False self.point_freeze = False - + self.engine = None 
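+        # self.engine stays None until the one-time ONNX export in forward()
+        # has run; it is then set so the export is not repeated on every call.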
+ def precompute_embedding(self, input_images): """precompute image embedding, require sliding window inference""" raise NotImplementedError @@ -203,6 +205,8 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False): param.requires_grad = not point_freeze self.point_freeze = point_freeze + + def forward( self, input_images, @@ -245,6 +249,8 @@ def forward( val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. """ + time00 = time.time() + image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: @@ -296,21 +302,59 @@ def forward( ): out, out_auto = self.image_embeddings, None else: + time0 = time.time() out, out_auto = self.image_encoder( input_images, with_point=point_coords is not None, with_label=class_vector is not None, ) + torch.cuda.synchronize() + print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}") + if False: + # breakpoint() + torch.onnx.export(self.image_encoder, + (input_images,), + "Encoder.onnx", + verbose=False, + opset_version=18 + ) + self.engine = True + input_images = None + time1 = time.time() # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: + time2 = time.time() logits, _ = self.class_head(out_auto, class_vector) + torch.cuda.synchronize() + print(f"Class Head Time: {time.time() - time2}") + + if self.engine is None: + torch.onnx.export(self.class_head, + (out_auto, class_vector,), + "class_head.onnx", + verbose=True, + opset_version=18 + ) + if False: + torch.onnx.export(self.point_head, + (out, point_coords, point_labels, {"class_vector":prompt_class}), + "point_head.onnx", + verbose=False, + opset_version=18 + ) + self.engine = True + if point_coords is not None: + time3 = time.time() point_logits = self.point_head( out, point_coords, point_labels, class_vector=prompt_class ) + torch.cuda.synchronize() + print(f"Point Head Time: {time.time() - time3}") + time4 = time.time() if patch_coords is None: logits = self.gaussian_combine( logits, @@ -325,6 +369,8 @@ def forward( logits = self.connected_components_combine( logits, point_logits, point_coords, point_labels, mapping_index ) + torch.cuda.synchronize() + print(f"Combine Time: {time.time() - time4}") else: logits = NINF_VALUE + torch.zeros( [bs, 1, *image_size], device=device, dtype=out.dtype @@ -341,6 +387,9 @@ def forward( mapping_index, ) + torch.cuda.synchronize() + print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}") + if kwargs.get("keep_cache", False) and class_vector is None: self.image_embeddings = out.detach() return logits From a1da5e2f99805633e29c270d676ebaf5a5616157 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 18 Jul 2024 19:49:16 -0700 Subject: [PATCH 02/31] stash Signed-off-by: Boris Fomitchev --- vista3d/modeling/vista3d.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) mode change 100644 => 100755 vista3d/modeling/vista3d.py diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py old mode 100644 new mode 100755 index 39505be..51f357d --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -310,7 +310,7 @@ def forward( ) torch.cuda.synchronize() print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}") - if False: + if self.engine is None: # breakpoint() torch.onnx.export(self.image_encoder, (input_images,), @@ -318,7 +318,6 
@@ def forward( verbose=False, opset_version=18 ) - self.engine = True input_images = None time1 = time.time() From d854bc59021a5986abd04d931ab7a1dbdeef6f3c Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 23 Jul 2024 17:07:54 -0700 Subject: [PATCH 03/31] Working TRT wrappers for encoder and class head Signed-off-by: Boris Fomitchev --- scripts/export.bash | 4 +- scripts/export.py | 28 +- scripts/utils/cast_utils.py | 96 ++++++ scripts/utils/export_utils.py | 324 +++++++++++++++++++ scripts/utils/trt_utils.py | 534 ++++++++++++++++++++++++++++++++ vista3d/modeling/segresnetds.py | 8 +- vista3d/modeling/vista3d.py | 70 ++--- 7 files changed, 1017 insertions(+), 47 deletions(-) create mode 100644 scripts/utils/cast_utils.py create mode 100644 scripts/utils/export_utils.py create mode 100644 scripts/utils/trt_utils.py diff --git a/scripts/export.bash b/scripts/export.bash index 69c734d..6e267ff 100755 --- a/scripts/export.bash +++ b/scripts/export.bash @@ -1,3 +1,3 @@ -# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz' +python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz' -python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true +# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true diff --git a/scripts/export.py b/scripts/export.py index 0423549..ff53b84 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -31,6 +31,8 @@ from .sliding_window import point_based_window_inferer, sliding_window_inference from .train import CONFIG from .utils.trans_utils import VistaPostTransform +from .utils.trt_utils import ExportWrapper, TRTWrapper +import time rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -60,7 +62,6 @@ def infer_wrapper(inputs, model, **kwargs): outputs = model(input_images=inputs, **kwargs) return outputs.transpose(1, 0) - class InferClass: def __init__(self, config_file="./configs/infer.yaml", **override): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -73,7 +74,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): parser.update(pairs=_args) # We do not use AMP for export - self.amp = False # parser.get_parsed_content("amp") + self.amp = parser.get_parsed_content("amp") input_channels = parser.get_parsed_content("input_channels") patch_size = parser.get_parsed_content("patch_size") self.patch_size = patch_size @@ -129,6 +130,17 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.save_transforms = transforms.Compose(save_transforms) self.prev_mask = None self.batch_data = None + + en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder, + input_names = ['x'], output_names = ['x_out']) + self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False) + # self.model.image_encoder.encoder.load_engine() + + cls_wrapper = ExportWrapper.wrap(self.model.class_head, + input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding']) + self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False) + # self.model.class_head.load_engine() + return def clear_cache(self): @@ -162,6 +174,7 @@ def infer( used together with prev_mask. 
If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. """ + time00=time.time() self.model.eval() if not isinstance(image_file, dict): image_file = {"image": image_file} @@ -248,12 +261,15 @@ def infer( finished = False if finished: break + print(f"Infer Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") return batch_data[0]["pred"] @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): + time00=time.time() self.model.eval() device = f"cuda:{rank}" if not isinstance(image_file, dict): @@ -295,6 +311,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): finished = False if finished: break + print(f"InferEverything Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") @@ -317,5 +335,11 @@ def batch_infer_everything(self, datalist=str, basedir=str): if __name__ == "__main__": + try: + #import torch_onnx + #torch_onnx.patch_torch(error_report=True) + print("patch succeeded") + except Exception: + pass fire, _ = optional_import("fire") fire.Fire(InferClass) diff --git a/scripts/utils/cast_utils.py b/scripts/utils/cast_utils.py new file mode 100644 index 0000000..ff58dde --- /dev/null +++ b/scripts/utils/cast_utils.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
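+
+# Autocast/casting helpers: context managers that sidestep bf16/fp16 autocast
+# regions, plus CastToFloat wrappers used by export_utils.replace_for_export
+# to keep norm layers running in fp32 during ONNX export.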
+ +from contextlib import nullcontext + +import torch + +def avoid_bfloat16_autocast_context(): + """ + If the current autocast context is bfloat16, + cast it to float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def avoid_float16_autocast_context(): + """ + If the current autocast context is float16, cast it to bfloat16 + if available (unless we're in jit) or float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.cuda.amp.autocast(dtype=torch.float32) + + if torch.cuda.is_bf16_supported(): + return torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + def __init__(self, mod): + super(CastToFloat, self).__init__() + self.mod = mod + + def forward(self, x): + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + def __init__(self, mod): + super(CastToFloatAll, self).__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.cuda.amp.autocast(enabled=False): + ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/scripts/utils/export_utils.py b/scripts/utils/export_utils.py new file mode 100644 index 0000000..d09cfce --- /dev/null +++ b/scripts/utils/export_utils.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
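+
+# Module-replacement machinery: swaps Apex/fused layers for ONNX-exportable
+# equivalents and, when do_cast=True, wraps norm layers in the fp32 casts
+# from cast_utils before export.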
+
+import os
+from contextlib import nullcontext
+from enum import Enum
+from typing import Callable, Dict, Optional, Type
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cast_utils import CastToFloat, CastToFloatAll
+
+class LinearWithBiasSkip(nn.Module):
+    def __init__(self, weight, bias, skip_bias_add):
+        super(LinearWithBiasSkip, self).__init__()
+        self.bias = bias
+        self.weight = weight
+        self.skip_bias_add = skip_bias_add
+
+    def forward(self, x):
+        if self.skip_bias_add:
+            return F.linear(x, self.weight), self.bias
+        return F.linear(x, self.weight, self.bias), None
+
+def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):
+    # Verify the model can be read, and is valid
+    ts_out = ts_model(*ts_input_list, **ts_input_dict)
+
+    all_good = True
+    for i, out in enumerate(ts_out):
+        expected = output_example[i]
+
+        if torch.is_tensor(expected):
+            tout = out.to('cpu')
+            print(f"Checking output {i}, shape: {expected.shape}:\n")
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch and that may be OK
+                this_good = False
+            if not this_good:
+                print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+                all_good = False
+    return all_good
+
+
+def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
+    # Verify the model can be read, and is valid
+    ort_out = sess.run(None, ort_input)
+    all_good = True
+    for i, out in enumerate(ort_out):
+        expected = output_example[i]
+
+        if torch.is_tensor(expected):
+            tout = torch.from_numpy(out)
+            print(f"Checking output {i}, shape: {expected.shape}:\n")
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch and that may be OK
+                this_good = False
+            if not this_good:
+                print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                all_good = False
+    return all_good
+
+
+apex_available = True
+
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+    from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
+    from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
+    from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+
+    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
+        """
+        Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export.
+        Args:
+            n: the FusedLayerNorm pytorch module to replace
+        Returns:
+            Equivalent LayerNorm module
+        """
+
+        p = next(n.parameters())
+        if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
+            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
+        elif isinstance(n, FastLayerNorm):
+            shape, eps, affine = n.weight.shape, n.epsilon, True
+        else:
+            return None
+
+        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
+        n_state = n.state_dict()
+        mod.load_state_dict(n_state)
+        return mod
+
+    def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
+        """
+        Replaces Apex's RowParallelLinear with an exportable Linear layer. This is required for ONNX export.
+        Args:
+            n: the RowParallelLinear pytorch module to replace
+        Returns:
+            Equivalent Linear module
+        """
+        if not isinstance(n, RowParallelLinear):
+            raise ValueError("This function can only change the RowParallelLinear module.")
+
+        dev = next(n.parameters()).device
+        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
+
+        n_state = n.state_dict()
+        mod.load_state_dict(n_state)
+        return mod
+
+    def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
+        """
+        Replaces Apex's ColumnParallelLinear or RowParallelLinear with nn.Linear
+        Args:
+            n: the nn.Module pytorch module to replace
+        Returns:
+            Equivalent Linear module
+        """
+        if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)):
+            raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.")
+
+        dev = next(n.parameters()).device
+        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
+
+        n_state = n.state_dict()
+        mod.load_state_dict(n_state)
+        return mod
+
+    def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
+        """
+        Replaces Apex's FusedScaleMaskSoftmax with an unfused equivalent. This is required for ONNX export.
+        Args:
+            n: the FusedScaleMaskSoftmax module to replace
+        Returns:
+            Equivalent unfused FusedScaleMaskSoftmax module
+        """
+        if not isinstance(n, FusedScaleMaskSoftmax):
+            raise ValueError("This function can only change the FusedScaleMaskSoftmax module.")
+
+        # disable the fusion only
+        mod = FusedScaleMaskSoftmax(
+            n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+        )
+
+        return mod
+
+    default_Apex_replacements = {
+        "FusedLayerNorm": replace_FusedLayerNorm,
+        "MixedFusedLayerNorm": replace_FusedLayerNorm,
+        "FastLayerNorm": replace_FusedLayerNorm,
+        "ESM1bLayerNorm" : replace_FusedLayerNorm,
+        "RowParallelLinear": replace_ParallelLinear,
+        "ColumnParallelLinear": replace_ParallelLinear,
+        "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
+    }
+
+except Exception as e:
+    default_Apex_replacements = {}
+    apex_available = False
+
+
+def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+    """
+    Generic function generator to replace BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.
+    Args:
+        BaseT : module type to replace
+        DestT : destination module type
+    Returns:
+        swap function to replace BaseT module with DestT
+    """
+
+    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
+        if not isinstance(mod, BaseT):
+            return None
+        args = [getattr(mod, name, None) for name in mod.__constants__]
+        out = DestT(*args)
+        return out
+
+    return expansion_fn
+
+
+def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
+    """
+    Replaces MatchedScaleMaskSoftmax with exportable softmax layer
+    Args:
+        n: module to replace
+    Returns:
+        exportable module
+    """
+    # including the import here to avoid circular imports
+    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
+
+    # disabling fusion for the MatchedScaleMaskSoftmax
+    mod = MatchedScaleMaskSoftmax(
+        n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+    )
+    return mod
+
+
+def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+    """
+    Generic function generator to replace BaseT module with DestT wrapper.
+    Args:
+        BaseT : module type to replace
+        DestT : destination module type
+    Returns:
+        swap function to replace BaseT module with DestT
+    """
+
+    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
+        out = DestT(mod)
+        return out
+
+    return expansion_fn
+
+
+def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
+    """
+    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+    for swapping nested modules through arbitrary levels of children
+
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+
+    """
+    for path, new_mod in mapping.items():
+        expanded_path = path.split(".")
+        parent_mod = model
+        for sub_path in expanded_path[:-1]:
+            parent_mod = parent_mod._modules[sub_path]  # noqa
+        parent_mod._modules[expanded_path[-1]] = new_mod  # noqa
+
+    return model
+
+
+def replace_modules(
+    model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None
+) -> nn.Module:
+    """
+    Top-level function to replace modules in model, specified by class name with a desired replacement.
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+        expansions : replacement dictionary: module class name -> replacement function generator
+    Returns:
+        model, possibly modified in-place
+    """
+    mapping: Dict[str, nn.Module] = {}
+    for name, m in model.named_modules():
+        m_type = type(m).__name__
+        if m_type in expansions:
+            # print (f"Found {m_type} in expansions ...")
+            swapped = expansions[m_type](m)
+            if swapped:
+                mapping[name] = swapped
+
+    print(f"Swapped {len(mapping)} modules")
+    swap_modules(model, mapping)
+    return model
+
+
+def script_module(m: nn.Module):
+    return torch.jit.script(m)
+
+
+script_replacements = {}
+
+
+def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
+    """
+    Top-level function to replace the default set of modules in model
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+        do_cast : wrap batch/instance/layer norms with fp32 casts
+    Returns:
+        model, possibly modified in-place
+    """
+    if apex_available:
+        print("Replacing Apex layers ...")
+        replace_modules(model, default_Apex_replacements)
+
+    if do_cast:
+        print("Adding casts around norms...")
+        cast_replacements = {
+            "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+            "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+            "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+            "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+            "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+        }
+        replace_modules(model, cast_replacements)
+
+    # This one has to be the last
+    replace_modules(model, script_replacements)
diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py
new file mode 100644
index 0000000..6275a0e
--- /dev/null
+++ b/scripts/utils/trt_utils.py
@@ -0,0 +1,534 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto.
Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+#
+# Copyright 2022 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import OrderedDict
+from typing import List
+from copy import copy
+import numpy as np
+import os
+import pickle
+from PIL import Image
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx
+from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx
+from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile
+from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
+from polygraphy.logger import G_LOGGER as L_
+
+import random
+from scipy import integrate
+import tensorrt as trt
+import torch
+import traceback
+
+from io import BytesIO
+from cuda import cudart
+from enum import Enum, auto
+
+import threading
+
+# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+# trt.init_libnvinfer_plugins(TRT_LOGGER, '')
+
+lock_sm = threading.Lock()
+
+@torch.jit.script
+def check_m(m):
+    t = torch.isnan(m)
+    return not torch.any(t)
+
+# Map of TRT dtype -> torch dtype
+trt_to_torch_dtype_dict = {
+    trt.int32 : torch.int32,
+    trt.float32: torch.float32,
+    trt.float16: torch.float16,
+    trt.bfloat16 : torch.float16,
+    trt.int64 : torch.int64,
+    trt.int8 : torch.int8,
+    trt.bool : torch.bool,
+}
+
+def CUASSERT(cuda_ret):
+    err = cuda_ret[0]
+    if err != 0:
+        raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t")
+    if len(cuda_ret) > 1:
+        return cuda_ret[1]
+    return None
+
+class ShapeException(Exception):
+    pass
+
+class Engine():
+    def __init__(
+        self,
+        engine_path,
+    ):
+        self.engine_path = engine_path
+        self.engine = None
+        self.context = None
+        self.tensors = OrderedDict()
+        self.cuda_graph_instance = None # cuda graph
+
+    def build(self, onnx_path,
+              profiles=[], fp16=False, bf16=False, tf32=True,
+              builder_optimization_level=3,
+              enable_all_tactics=True,
+              direct_io=False,
+              timing_cache=None,
+              update_output_names=None):
+        L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+        config_kwargs = {
+            'builder_optimization_level' : builder_optimization_level,
+            'direct_io' : direct_io,
+        }
+        if not enable_all_tactics:
+            config_kwargs['tactic_sources'] = []
+
+        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+        if update_output_names:
+            L_.info(f"Updating network outputs to {update_output_names}")
+            network = ModifyNetworkOutputs(network, update_output_names)
+        # with L.verbosity(0):
+        L_.info("Calling engine_from_network...")
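+        # Note: engine_from_network() runs the full TensorRT builder here; at
+        # builder_optimization_level 5 this can take minutes, and tactic
+        # timings are reused across builds when timing_cache is supplied.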
+ + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, + bf16=bf16, + tf32=tf32, + profiles=profiles, + load_timing_cache=timing_cache, + **config_kwargs + ), + save_timing_cache=timing_cache + ) + self.engine = engine + + def save(self): + save_engine(self.engine, path=self.engine_path) + + def load(self): + L_.info(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, profile_num=0, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + self.cur_profile = profile_num + # L_.info(self.input_names) + # L_.info(self.output_names) + + def allocate_buffers(self, device): + # allocate outputs + e = self.engine + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape=ctx.get_tensor_shape(binding) + t = torch.empty(list(shape), dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + @staticmethod + def check_shape(shape, profile): + shape = list(shape) + minlist = profile[0] + maxlist = profile[2] + good = True + for i, s in enumerate(shape): + if s < minlist[i] or s > maxlist[i]: + good = False + return good + + def set_inputs(self, feed_dict, stream): + e = self.engine + ctx = self.context + last_profile = self.cur_profile + + def try_set_inputs(): + for binding, t in feed_dict.items(): + if t is not None: + t = t.contiguous() + shape = t.shape + # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) + # if not self.check_shape(shape, mincurmax): + # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break; + except ShapeException: + next_profile = (self.cur_profile+1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + # torch.cuda.synchronize() + + left = ctx.infer_shapes() + assert len(left)==0 + + + + def infer(self, stream, use_cuda_graph=False): + e = self.engine + ctx = self.context + if use_cuda_graph: + if self.cuda_graph_instance is not None: + CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + CUASSERT(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError(f"ERROR: inference failed.") + # capture cuda graph + CUASSERT(cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)) + self.context.execute_async_v3(stream) + graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance 
= CUASSERT(cudart.cudaGraphInstantiate(graph, 0)) + print("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + CUASSERT(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError(f"ERROR: inference failed.") + + return self.tensors + + +class ExportWrapper(torch.nn.Module): + """ + An auxiliary class to facilitate ONNX->TRT export of a module + """ + def __init__(self, model, + input_names=None, + output_names=None, + precision="fp32"): + super().__init__() + self.input_names = input_names + self.output_names = output_names + self.dynamic_shapes = None + + self.model = model + self.precision = precision + + def get_export_obj(self): + return self.model + + def sample_profile(self, min_len=None, max_len=None): + return None + + def can_handle(self, **args): + return True + + @classmethod + def wrap(cls, model, **args): + wrapper = cls(model, **args) + return wrapper + + +@torch.jit.script +def no_nans(m): + t = torch.isnan(m) + return not torch.any(t) + +class TRTWrapper(torch.nn.Module): + """ + An auxiliary class to implement running of TRT optimized engines + + """ + def __init__(self, path, exp, use_cuda_graph=False): + super().__init__() + self.exp_wrapper = None + self.prev_wrapper = None + self.profiles = None + self.engine = None + self.jit_model = None + self.onnx_runner = None + self.path = path + self.use_cuda_graph=use_cuda_graph + if exp is not None: + self.attach(exp) + + @property + def engine_path(self): + return self.path + '.plan' + @property + def jit_path(self): + return self.path + '.ts' + @property + def onnx_path(self): + return self.path + '.onnx' + @property + def profiles_path(self): + return self.path + '.profiles.pkl' + + def has_engine(self): + return self.engine is not None + + def has_onnx(self): + return os.path.exists(self.onnx_path) + + def has_jit(self): + return os.path.exists(self.jit_path) + + def has_profiles(self): + return os.path.exists(self.profiles_path) + + def load_engine(self): + try: + engine=Engine(self.engine_path) + engine.load() + engine.activate() + self.engine = engine + except Exception as e: + print (f"Exception while loading the engine:\n{e}") + pass + + def load_jit(self): + try: + self.jit_model = torch.jit.load(self.jit_path) + except Exception: + pass + + def load_onnx(self, providers=["CUDAExecutionProvider"]): + try: + onnx_runner = OnnxrtRunner( + session_from_onnx(self.onnx_path, providers=providers) + ) + onnx_runner.activate() + self.onnx_runner = onnx_runner + except Exception: + pass + + def load_profiles(self): + with open(self.profiles_path, "rb") as fp: + profiles = pickle.load(fp) + self.profiles = profiles + return profiles + + def save_profiles(self): + with open(self.profiles_path, "wb") as fp: + pickle.dump(self.profiles, fp) + + def attach(self, exp): + self.exp_wrapper = exp + self.input_names = exp.input_names + self.output_names = exp.output_names + + def can_handle(self, **args): + return self.exp_wrapper.can_handle(**args) + + def inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name=self.engine.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def forward(self, **args): + try: + if self.engine is not None: + if self.can_handle(**args): + # print(f"Running {self.engine_path}...") + # forward_trt is not thread safe as we do not use per-thread execution contexts + with lock_sm: + return self.forward_trt(args) + elif self.jit_model is not None: + return self.jit_model.forward(**args) + elif 
self.onnx_runner is not None: + print(f"Running {self.onnx_path}...") + ret = self.onnx_runner.infer(args) + ret = list(ret.values()) + ret = [r.cuda() for r in ret] + if len(ret)==1: + ret = ret[0] + return ret + except Exception as e: + print(f"Exception: {e}\nFalling back to Pytorch ...") + + return self.exp_wrapper.get_export_obj().forward(**args) + + def forward_trt(self, trt_inputs): + stream = torch.cuda.Stream(device=torch.cuda.current_device()) + self.engine.set_inputs(trt_inputs, stream.cuda_stream) + self.engine.allocate_buffers(torch.device("cuda")) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + ret = list(ret.values()) + #for r in ret: + # assert no_nans(r), "NaNs in TRT output!" + if len(ret)==1: + ret = ret[0] + return ret + + def forward_trt_runner(self, trt_inputs): + with TrtRunner(self.engine) as runner: + ret = runner.infer(trt_inputs) + ret = list(ret.values()) + ret = [r.cuda() for r in ret] + check = [check_m(r) for r in ret] + if len(ret)==1: + ret = ret[0] + return ret + + + def build_engine(self, input_profiles=[], + fp16=False, bf16=False, tf32=False, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True): + profiles = [] + if len(input_profiles) > 0: + for input_profile in input_profiles: + if isinstance(input_profile, Profile): + profiles.append(input_profile) + else: + p = Profile() + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + profiles.append(p) + self.profiles = profiles + self.save_profiles() + + engine = Engine(self.path+'.plan') + engine.build(self.onnx_path, profiles, + fp16=fp16, + bf16=bf16, + tf32=tf32, + direct_io=direct_io, + builder_optimization_level=builder_optimization_level, + enable_all_tactics=enable_all_tactics + ) + engine.activate() + self.engine = engine + + def jit_export(self, input_example, + verbose=False, ): + self.jit_model = torch.jit.trace( + self.exp_wrapper, + input_example, + ).eval() + self.jit_model = torch.jit.freeze(self.jit_model) + torch.jit.save(self.jit_model, self.jit_path) + + def onnx_export(self, input_example, + dynamo=False, + onnx_registry=None, + dynamic_shapes=None, + verbose=False, + opset_version=18, + ): + L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") + model = self.exp_wrapper.get_export_obj() + from .export_utils import replace_for_export + replace_for_export(model, do_cast=True) + + if dynamo: + torch.onnx.export(model, + input_example, + self.onnx_path, + dynamo=dynamo, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_shapes=dynamic_shapes + ) + else: + torch.onnx.export(model, + input_example, + self.onnx_path, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=dynamic_shapes + ) + L_.info("Folding constants...") + model_onnx = onnx_from_path(self.onnx_path) + fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) + L_.info("Done folding constants.") + + L_.info("Saving model...") + save_onnx( + model_onnx, + self.onnx_path, + ) + L_.info("Done saving model.") + + def build_and_save(self, + input_example, + dynamo=False, + verbose=False, + input_profiles=[], + fp16=False, bf16=False, tf32=True, + builder_optimization_level=3, + 
direct_io=False, + enable_all_tactics=True): + return + if not self.has_engine(): + if not self.has_onnx(): + self.onnx_export( + input_example, + dynamo=dynamo, + verbose=verbose, + ) + self.build_engine( + fp16=fp16, tf32=tf32, + direct_io=direct_io, + builder_optimization_level=5, + enable_all_tactics=enable_all_tactics) + self.engine.save() + + + diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 0a39508..6fabe2a 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -472,7 +472,7 @@ def _forward( f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}" ) - x_down = self.encoder(x) + x_down = self.encoder(x=x) x_down.reverse() x = x_down.pop(0) @@ -482,8 +482,9 @@ def _forward( outputs: list[torch.Tensor] = [] outputs_auto: list[torch.Tensor] = [] - x_ = x.clone() + if with_point: + x_ = x.clone() i = 0 for level in self.up_layers: x = level["upsample"](x) @@ -495,7 +496,8 @@ def _forward( i = i + 1 outputs.reverse() - x = x_ + x = x_ + if with_label: i = 0 for level in self.up_layers_auto: diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index 51f357d..2cf9663 100755 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -302,58 +302,48 @@ def forward( ): out, out_auto = self.image_embeddings, None else: + # print(input_images.dtype) + self.image_encoder.encoder.build_and_save( + (input_images,), + dynamo=False, + verbose=False, + fp16=True, tf32=True, + builder_optimization_level=5, + enable_all_tactics=True + ) + time0 = time.time() out, out_auto = self.image_encoder( - input_images, + x=input_images, with_point=point_coords is not None, with_label=class_vector is not None, ) - torch.cuda.synchronize() - print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}") - if self.engine is None: - # breakpoint() - torch.onnx.export(self.image_encoder, - (input_images,), - "Encoder.onnx", - verbose=False, - opset_version=18 - ) - - input_images = None - time1 = time.time() - + # torch.cuda.synchronize() + # time1 = time.time() + # print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}") + input_images = None # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: + self.class_head.build_and_save( + (out_auto, class_vector,), + fp16=True, tf32=True, + dynamo=False, + verbose=False, + ) time2 = time.time() - logits, _ = self.class_head(out_auto, class_vector) - torch.cuda.synchronize() - print(f"Class Head Time: {time.time() - time2}") - - if self.engine is None: - torch.onnx.export(self.class_head, - (out_auto, class_vector,), - "class_head.onnx", - verbose=True, - opset_version=18 - ) - if False: - torch.onnx.export(self.point_head, - (out, point_coords, point_labels, {"class_vector":prompt_class}), - "point_head.onnx", - verbose=False, - opset_version=18 - ) - self.engine = True + logits, _ = self.class_head(src=out_auto, class_vector=class_vector) + # torch.cuda.synchronize() + # print(f"Class Head Time: {time.time() - time2}") if point_coords is not None: time3 = time.time() point_logits = self.point_head( out, point_coords, point_labels, class_vector=prompt_class ) - torch.cuda.synchronize() - print(f"Point Head Time: {time.time() - time3}") - time4 = time.time() + # torch.cuda.synchronize() + # print(f"Point Head Time: {time.time() - time3}") + # time4 = time.time() if patch_coords is None: logits = self.gaussian_combine( 
logits,
@@ -368,8 +358,8 @@ def forward(
                 logits = self.connected_components_combine(
                     logits, point_logits, point_coords, point_labels, mapping_index
                 )
-            torch.cuda.synchronize()
-            print(f"Combine Time: {time.time() - time4}")
+            # torch.cuda.synchronize()
+            # print(f"Combine Time: {time.time() - time4}")
         else:
             logits = NINF_VALUE + torch.zeros(
                 [bs, 1, *image_size], device=device, dtype=out.dtype
             )
@@ -387,7 +377,7 @@ def forward(
             )
 
         torch.cuda.synchronize()
-        print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
+        # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
 
         if kwargs.get("keep_cache", False) and class_vector is None:
             self.image_embeddings = out.detach()

From 818a548eed32a3b9dc8c4aef410359078bd16a71 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Tue, 23 Jul 2024 21:07:09 -0700
Subject: [PATCH 04/31] Cleaned up, working TRT wrapping

Signed-off-by: Boris Fomitchev
---
 scripts/export.py           |  8 ++++----
 vista3d/modeling/vista3d.py | 40 +++++++++++++++++++------------------
 2 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/scripts/export.py b/scripts/export.py
index 59784d3..7f1a3c2 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -133,13 +133,13 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
 
         en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
                                         input_names = ['x'], output_names = ['x_out'])
-        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
-        # self.model.image_encoder.encoder.load_engine()
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
+        self.model.image_encoder.encoder.load_engine()
 
         cls_wrapper = ExportWrapper.wrap(self.model.class_head,
                                          input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding'])
-        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
-        # self.model.class_head.load_engine()
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
+        self.model.class_head.load_engine()
 
         return
 
diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py
index 2cf9663..accfbde 100755
--- a/vista3d/modeling/vista3d.py
+++ b/vista3d/modeling/vista3d.py
@@ -302,15 +302,16 @@ def forward(
         ):
             out, out_auto = self.image_embeddings, None
         else:
-            # print(input_images.dtype)
-            self.image_encoder.encoder.build_and_save(
-                (input_images,),
-                dynamo=False,
-                verbose=False,
-                fp16=True, tf32=True,
-                builder_optimization_level=5,
-                enable_all_tactics=True
-            )
+            # Support for TRT wrapping
+            if hasattr(self.image_encoder.encoder, "build_and_save"):
+                self.image_encoder.encoder.build_and_save(
+                    (input_images,),
+                    dynamo=False,
+                    verbose=False,
+                    fp16=True, tf32=True,
+                    builder_optimization_level=5,
+                    enable_all_tactics=True
+                )
 
             time0 = time.time()
             out, out_auto = self.image_encoder(
@@ -325,19 +326,20 @@ def forward(
             # force releasing memories that set to None
             torch.cuda.empty_cache()
         if class_vector is not None:
-            self.class_head.build_and_save(
-                (out_auto, class_vector,),
-                fp16=True, tf32=True,
-                dynamo=False,
-                verbose=False,
-            )
-            # time2 = time.time()
+            if hasattr(self.class_head, "build_and_save"):
+                self.class_head.build_and_save(
+                    (out_auto, class_vector,),
+                    fp16=True, tf32=True,
+                    dynamo=False,
+                    verbose=False,
+                )
+            # time2 = time.time()
             logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
             # torch.cuda.synchronize()
             # print(f"Class Head Time: {time.time() - time2}")
 
         if point_coords is
not None: - time3 = time.time() + # time3 = time.time() point_logits = self.point_head( out, point_coords, point_labels, class_vector=prompt_class ) @@ -376,8 +378,8 @@ def forward( mapping_index, ) - torch.cuda.synchronize() - # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}") + # torch.cuda.synchronize() + # print(f"Total time : {time.time() - time00} shape : {logits.shape}") if kwargs.get("keep_cache", False) and class_vector is None: self.image_embeddings = out.detach() From 3e4a84b2fd7ecb95f6e5308da40d09dba81b7904 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 31 Jul 2024 12:59:33 -0700 Subject: [PATCH 05/31] Cleanup Signed-off-by: Boris Fomitchev --- configs/infer.yaml | 1 + scripts/export.bash | 3 - scripts/export.py | 353 ------------------------------------- scripts/infer.py | 21 +++ scripts/utils/trt_utils.py | 1 - 5 files changed, 22 insertions(+), 357 deletions(-) delete mode 100755 scripts/export.bash delete mode 100644 scripts/export.py diff --git a/configs/infer.yaml b/configs/infer.yaml index 89d018b..8db5bbb 100644 --- a/configs/infer.yaml +++ b/configs/infer.yaml @@ -1,3 +1,4 @@ +trt: true amp: true input_channels: 1 patch_size: [128, 128, 128] diff --git a/scripts/export.bash b/scripts/export.bash deleted file mode 100755 index 6e267ff..0000000 --- a/scripts/export.bash +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz' - -# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true diff --git a/scripts/export.py b/scripts/export.py deleted file mode 100644 index 7f1a3c2..0000000 --- a/scripts/export.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import sys -from functools import partial - -import monai -import numpy as np -import torch -import torch.distributed as dist -from monai import transforms -from monai.apps.auto3dseg.auto_runner import logger -from monai.auto3dseg.utils import datafold_read -from monai.bundle import ConfigParser -from monai.bundle.scripts import _pop_args, _update_args -from monai.data import decollate_batch, list_data_collate, partition_dataset -from monai.utils import optional_import - -from vista3d import vista_model_registry - -from .sliding_window import point_based_window_inferer, sliding_window_inference -from .train import CONFIG -from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point -from .utils.trt_utils import ExportWrapper, TRTWrapper -import time - -rearrange, _ = optional_import("einops", name="rearrange") -sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) -IGNORE_PROMPT = set( - [ - 2, # kidney - 16, # prostate or uterus - 18, # rectum - 20, # lung - 21, # bone - 23, # lung tumor - 24, # pancreatic tumor - 25, # hepatic vessel - 26, # hepatic tumor - 27, # colon cancer primaries - 128, # bone lesion - 129, # kidney mass - 130, # liver tumor - 131, # vertebrae L6 - 132, - ] -) # airway -EVERYTHING_PROMPT = list(set([i + 1 for i in range(133)]) - IGNORE_PROMPT) - - -def infer_wrapper(inputs, model, **kwargs): - outputs = model(input_images=inputs, **kwargs) - return outputs.transpose(1, 0) - - -class InferClass: - def __init__(self, config_file="./configs/infer.yaml", **override): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) - - _args = _update_args(config_file=config_file, **override) - config_file_ = _pop_args(_args, "config_file")[0] - - parser = ConfigParser() - parser.read_config(config_file_) - parser.update(pairs=_args) - - self.amp = parser.get_parsed_content("amp") - input_channels = parser.get_parsed_content("input_channels") - patch_size = parser.get_parsed_content("patch_size") - self.patch_size = patch_size - - ckpt_name = parser.get_parsed_content("infer")["ckpt_name"] - output_path = parser.get_parsed_content("infer")["output_path"] - if not os.path.exists(output_path): - os.makedirs(output_path, exist_ok=True) - - CONFIG["handlers"]["file"]["filename"] = parser.get_parsed_content("infer")[ - "log_output_file" - ] - logging.config.dictConfig(CONFIG) - self.infer_transforms = parser.get_parsed_content("transforms_infer") - - self.device = torch.device("cuda:0") - model_registry = parser.get_parsed_content("model") - model = vista_model_registry[model_registry]( - in_channels=input_channels, image_size=patch_size - ) - self.model = model.to(self.device) - - pretrained_ckpt = torch.load(ckpt_name, map_location=self.device) - self.model.load_state_dict(pretrained_ckpt, strict=False) - logger.debug(f"[debug] checkpoint {ckpt_name:s} loaded") - post_transforms = [ - VistaPostTransform(keys="pred"), - transforms.Invertd( - keys="pred", - transform=self.infer_transforms, - orig_keys="image", - meta_keys="pred_meta_dict", - orig_meta_keys="image_meta_dict", - meta_key_postfix="meta_dict", - nearest_interp=True, - to_tensor=True, - ), - ] - - # For Vista3d, sigmoid is always used, but for visualization, argmax is needed - save_transforms = [ - transforms.SaveImaged( - keys="pred", - meta_keys="pred_meta_dict", - output_dir=output_path, - output_postfix="seg", - resample=False, - data_root_dir=None, - print_log=False, - ) - ] - self.post_transforms = transforms.Compose(post_transforms) - 
self.save_transforms = transforms.Compose(save_transforms) - self.prev_mask = None - self.batch_data = None - - en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder, - input_names = ['x'], output_names = ['x_out']) - self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper) - self.model.image_encoder.encoder.load_engine() - - cls_wrapper = ExportWrapper.wrap(self.model.class_head, - input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding']) - self.model.class_head = TRTWrapper("ClassHead", cls_wrapper) - self.model.class_head.load_engine() - - return - - def clear_cache(self): - self.prev_mask = None - self.batch_data = None - - def transform_points(self, point, affine): - """transform point to the coordinates of the transformed image - point: numpy array [bs, N, 3] - """ - bs, N = point.shape[:2] - point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1) - point = rearrange(point, "b n d -> d (b n)") - point = affine @ point - point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3] - return point - - @torch.no_grad() - def infer( - self, - image_file, - point=None, - point_label=None, - label_prompt=None, - prompt_class=None, - save_mask=False, - point_start=0, - ): - """Infer a single image_file. If save_mask is true, save the argmax prediction to disk. If false, - do not save and return the probability maps (usually used by autorunner emsembler). point_start is - used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save - time and avoid repeated inference. This is by default disabled. - """ - time00=time.time() - self.model.eval() - if not isinstance(image_file, dict): - image_file = {"image": image_file} - if self.batch_data is not None: - batch_data = self.batch_data - else: - batch_data = self.infer_transforms(image_file) - if label_prompt is not None: - batch_data["label_prompt"] = label_prompt - batch_data = list_data_collate([batch_data]) - self.batch_data = batch_data - if point is not None: - if type(point) is list: - point = np.array(point)[np.newaxis, ...] - point_label = np.array(point_label)[np.newaxis, ...] 
- point = self.transform_points( - point, - np.linalg.inv(batch_data["image"].affine[0]) - @ batch_data["image"].meta["original_affine"][0].numpy(), - ) - self.sliding_window_inferer = partial( - point_based_window_inferer, point_start=point_start - ) - else: - self.sliding_window_inferer = sliding_window_inference - device_list_input = [self.device, self.device, "cpu"] - device_list_output = [self.device, "cpu", "cpu"] - for _device_in, _device_out in zip(device_list_input, device_list_output): - try: - with torch.cuda.amp.autocast(enabled=self.amp): - batch_data["pred"] = self.sliding_window_inferer( - inputs=batch_data["image"].to(_device_in), - roi_size=self.patch_size, - sw_batch_size=1, - predictor=partial(infer_wrapper, model=self.model), - mode="gaussian", - overlap=0.625, - progress=True, - sw_device=self.device, - device=_device_out, - point_coords=( - torch.tensor(point).to(_device_in) - if point is not None - else None - ), - point_labels=( - torch.tensor(point_label).to(_device_in) - if point_label is not None - else None - ), - class_vector=( - torch.tensor(label_prompt).to(_device_in) - if label_prompt is not None - else None - ), - prompt_class=( - torch.tensor(prompt_class).to(_device_in) - if prompt_class is not None - else None - ), - prev_mask=( - torch.tensor(self.prev_mask).to(_device_in) - if self.prev_mask is not None - else None - ), - ) - - if not hasattr(batch_data["pred"], "meta"): - batch_data["pred"] = monai.data.MetaTensor( - batch_data["pred"], - affine=batch_data["image"].meta["affine"], - meta=batch_data["image"].meta, - ) - self.prev_mask = batch_data["pred"] - if label_prompt is None and point is not None: - batch_data["pred"] = get_largest_connected_component_point( - batch_data["pred"], point_coords=point, point_labels=point_label - ) - batch_data["image"] = batch_data["image"].to("cpu") - batch_data["pred"] = batch_data["pred"].to("cpu") - torch.cuda.empty_cache() - batch_data = [ - self.post_transforms(i) for i in decollate_batch(batch_data) - ] - if save_mask: - batch_data = [self.save_transforms(i) for i in batch_data] - - finished = True - except RuntimeError as e: - if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")): - raise e - finished = False - if finished: - break - print(f"Infer Time: {time.time() - time00}") - - if not finished: - raise RuntimeError("Infer not finished due to OOM.") - return batch_data[0]["pred"] - - @torch.no_grad() - def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): - time00=time.time() - self.model.eval() - device = f"cuda:{rank}" - if not isinstance(image_file, dict): - image_file = {"image": image_file} - batch_data = self.infer_transforms(image_file) - batch_data["label_prompt"] = label_prompt - batch_data = list_data_collate([batch_data]) - device_list_input = [device, device, "cpu"] - device_list_output = [device, "cpu", "cpu"] - for _device_in, _device_out in zip(device_list_input, device_list_output): - try: - with torch.cuda.amp.autocast(enabled=self.amp): - batch_data["pred"] = sliding_window_inference( - inputs=batch_data["image"].to(_device_in), - roi_size=self.patch_size, - sw_batch_size=1, - predictor=partial(infer_wrapper, model=self.model), - mode="gaussian", - overlap=0.625, - sw_device=device, - device=_device_out, - class_vector=torch.tensor(label_prompt).to(_device_in), - ) - if not hasattr(batch_data["pred"], "meta"): - batch_data["pred"] = monai.data.MetaTensor( - batch_data["pred"], - affine=batch_data["image"].meta["affine"], - 
meta=batch_data["image"].meta, - ) - torch.cuda.empty_cache() - batch_data = [ - self.post_transforms(i) for i in decollate_batch(batch_data) - ] - batch_data = [self.save_transforms(i) for i in batch_data] - finished = True - except RuntimeError as e: - if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")): - raise e - finished = False - if finished: - break - print(f"InferEverything Time: {time.time() - time00}") - - if not finished: - raise RuntimeError("Infer not finished due to OOM.") - - @torch.no_grad() - def batch_infer_everything(self, datalist=str, basedir=str): - train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=0) - train_files = [_["image"] for _ in train_files] - dist.init_process_group(backend="nccl", init_method="env://") - world_size = dist.get_world_size() - rank = dist.get_rank() - # no need to wrap model with DistributedDataParallel - self.model = self.model.to(f"cuda:{rank}") - infer_files = partition_dataset( - data=train_files, - shuffle=False, - num_partitions=world_size, - even_divisible=False, - )[rank] - self.infer(infer_files, label_prompt=EVERYTHING_PROMPT, rank=rank) - - -if __name__ == "__main__": - try: - #import torch_onnx - #torch_onnx.patch_torch(error_report=True) - print("patch succeeded") - except Exception: - pass - fire, _ = optional_import("fire") - fire.Fire(InferClass) diff --git a/scripts/infer.py b/scripts/infer.py index 924a5ab..e01dc96 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -31,6 +31,8 @@ from .sliding_window import point_based_window_inferer, sliding_window_inference from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point +from .utils.trt_utils import ExportWrapper, TRTWrapper +import time rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -73,6 +75,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): parser.update(pairs=_args) self.amp = parser.get_parsed_content("amp") + self.trt = parser.get_parsed_content("trt") input_channels = parser.get_parsed_content("input_channels") patch_size = parser.get_parsed_content("patch_size") self.patch_size = patch_size @@ -128,6 +131,16 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.save_transforms = transforms.Compose(save_transforms) self.prev_mask = None self.batch_data = None + if self.trt: + en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder, + input_names = ['x'], output_names = ['x_out']) + self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper) + self.model.image_encoder.encoder.load_engine() + + cls_wrapper = ExportWrapper.wrap(self.model.class_head, + input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding']) + self.model.class_head = TRTWrapper("ClassHead", cls_wrapper) + self.model.class_head.load_engine() return def clear_cache(self): @@ -161,6 +174,7 @@ def infer( used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. 
""" + time00=time.time() self.model.eval() if not isinstance(image_file, dict): image_file = {"image": image_file} @@ -255,12 +269,15 @@ def infer( finished = False if finished: break + print(f"Infer Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") return batch_data[0]["pred"] @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): + time00=time.time() self.model.eval() device = f"cuda:{rank}" if not isinstance(image_file, dict): @@ -302,6 +319,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): finished = False if finished: break + print(f"InferEverything Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") @@ -324,5 +343,7 @@ def batch_infer_everything(self, datalist=str, basedir=str): if __name__ == "__main__": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True fire, _ = optional_import("fire") fire.Fire(InferClass) diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index 6275a0e..0fc6d6b 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -515,7 +515,6 @@ def build_and_save(self, builder_optimization_level=3, direct_io=False, enable_all_tactics=True): - return if not self.has_engine(): if not self.has_onnx(): self.onnx_export( From 91606c385655e9110550b3e9c8b40cf3df3f011b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:15:57 +0000 Subject: [PATCH 06/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- data/README.md | 2 +- scripts/debugger.py | 10 +- scripts/infer.py | 20 +- scripts/utils/cast_utils.py | 11 +- scripts/utils/export_utils.py | 106 ++++++--- scripts/utils/trt_utils.py | 367 ++++++++++++++++++-------------- vista3d/modeling/segresnetds.py | 2 +- vista3d/modeling/vista3d.py | 17 +- 9 files changed, 320 insertions(+), 217 deletions(-) diff --git a/README.md b/README.md index 9ce85fd..1ac5bc9 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://githu ## License -The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license. +The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license. ## Reference diff --git a/data/README.md b/data/README.md index 3fbdde0..03354c3 100644 --- a/data/README.md +++ b/data/README.md @@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds to one dataset. ##### 2. Add label_dict.json and label_mapping.json -Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. +Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. ## SupverVoxel Generation 1. Download the segment anything repo and download the ViT-H weights diff --git a/scripts/debugger.py b/scripts/debugger.py index c924862..b2568f4 100644 --- a/scripts/debugger.py +++ b/scripts/debugger.py @@ -123,8 +123,12 @@ def on_button_click(event, ax=ax): print("-- segmenting ---") self.generate_mask() print("-- done ---") - print("-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. 
---") - print("-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---") + print( + "-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---" + ) + print( + "-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---" + ) print("-- Note: CTRL + Right Click will be adding negative points. ---") print( "-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---" @@ -132,7 +136,7 @@ def on_button_click(event, ax=ax): print( "-- Note: Click points not matching class prompts will also cause confusion. ---" ) - + self.update_slice(ax) # self.point_start = len(self.clicked_points) diff --git a/scripts/infer.py b/scripts/infer.py index e01dc96..b55be61 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -12,6 +12,7 @@ import logging import os import sys +import time from functools import partial import monai @@ -32,7 +33,6 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point from .utils.trt_utils import ExportWrapper, TRTWrapper -import time rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -132,13 +132,19 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.prev_mask = None self.batch_data = None if self.trt: - en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder, - input_names = ['x'], output_names = ['x_out']) + en_wrapper = ExportWrapper.wrap( + self.model.image_encoder.encoder, + input_names=["x"], + output_names=["x_out"], + ) self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper) self.model.image_encoder.encoder.load_engine() - cls_wrapper = ExportWrapper.wrap(self.model.class_head, - input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding']) + cls_wrapper = ExportWrapper.wrap( + self.model.class_head, + input_names=["src", "class_vector"], + output_names=["masks", "class_embedding"], + ) self.model.class_head = TRTWrapper("ClassHead", cls_wrapper) self.model.class_head.load_engine() return @@ -174,7 +180,7 @@ def infer( used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. """ - time00=time.time() + time00 = time.time() self.model.eval() if not isinstance(image_file, dict): image_file = {"image": image_file} @@ -277,7 +283,7 @@ def infer( @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): - time00=time.time() + time00 = time.time() self.model.eval() device = f"cuda:{rank}" if not isinstance(image_file, dict): diff --git a/scripts/utils/cast_utils.py b/scripts/utils/cast_utils.py index ff58dde..329033e 100644 --- a/scripts/utils/cast_utils.py +++ b/scripts/utils/cast_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. 
Any use, reproduction, @@ -26,6 +26,7 @@ import torch + def avoid_bfloat16_autocast_context(): """ If the current autocast context is bfloat16, @@ -70,7 +71,9 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) return new_dict elif isinstance(x, tuple): - return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + return tuple( + cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x + ) class CastToFloat(torch.nn.Module): @@ -92,5 +95,7 @@ def __init__(self, mod): def forward(self, *args): from_dtype = args[0].dtype with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + ret = self.mod.forward( + *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32) + ) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/scripts/utils/export_utils.py b/scripts/utils/export_utils.py index d09cfce..b7eecbd 100644 --- a/scripts/utils/export_utils.py +++ b/scripts/utils/export_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, @@ -22,16 +22,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from contextlib import nullcontext -from enum import Enum from typing import Callable, Dict, Optional, Type -import logging + import torch import torch.nn as nn import torch.nn.functional as F -from .cast_utils import CastToFloat, CastToFloatAll +from .cast_utils import CastToFloat + class LinearWithBiasSkip(nn.Module): def __init__(self, weight, bias, skip_bias_add): @@ -45,7 +43,10 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None -def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01): + +def run_ts_and_compare( + ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01 +): # Verify the model can be read, and is valid ts_out = ts_model(*ts_input_list, **ts_input_dict) @@ -54,16 +55,20 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c expected = output_example[i] if torch.is_tensor(expected): - tout = out.to('cpu') + tout = out.to("cpu") print(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: - if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): + if not torch.allclose( + tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance + ): this_good = False except Exception: # there may ne size mismatch and it may be OK this_good = False if not this_good: - print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") + print( + f"Results mismatch! 
PyTorch(expected):\n{expected}\nTorchScript:\n{tout}" + ) all_good = False return all_good @@ -80,12 +85,19 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): print(f"Checking output {i}, shape: {expected.shape}:\n") this_good = True try: - if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): + if not torch.allclose( + tout, + expected.cpu(), + rtol=check_tolerance, + atol=100 * check_tolerance, + ): this_good = False except Exception: # there may ne size mismatch and it may be OK this_good = False if not this_good: - print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + print( + f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}" + ) all_good = False return all_good @@ -96,7 +108,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): from apex.contrib.layer_norm.layer_norm import FastLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear + from apex.transformer.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: """ @@ -115,7 +130,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: else: return None - mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype) + mod = nn.LayerNorm( + shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype + ) n_state = n.state_dict() mod.load_state_dict(n_state) return mod @@ -129,7 +146,9 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: Equivalent LayerNorm module """ if not isinstance(n, RowParallelLinear): - raise ValueError("This function can only change the RowParallelLinear module.") + raise ValueError( + "This function can only change the RowParallelLinear module." + ) dev = next(n.parameters()).device mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) @@ -146,8 +165,12 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: Returns: Equivalent Linear module """ - if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)): - raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.") + if not ( + isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear) + ): + raise ValueError( + "This function can only change the ColumnParallelLinear or RowParallelLinear module." + ) dev = next(n.parameters()).device mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) @@ -165,11 +188,19 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: Equivalent LayerNorm module """ if not isinstance(n, FusedScaleMaskSoftmax): - raise ValueError("This function can only change the FusedScaleMaskSoftmax module.") + raise ValueError( + "This function can only change the FusedScaleMaskSoftmax module." 
+ ) # disable the fusion only mod = FusedScaleMaskSoftmax( - n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, ) return mod @@ -178,18 +209,20 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: "FusedLayerNorm": replace_FusedLayerNorm, "MixedFusedLayerNorm": replace_FusedLayerNorm, "FastLayerNorm": replace_FusedLayerNorm, - "ESM1bLayerNorm" : replace_FusedLayerNorm, + "ESM1bLayerNorm": replace_FusedLayerNorm, "RowParallelLinear": replace_ParallelLinear, "ColumnParallelLinear": replace_ParallelLinear, "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, } -except Exception as e: +except Exception: default_Apex_replacements = {} apex_available = False -def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def simple_replace( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. Args: @@ -218,18 +251,28 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: exportable module """ # including the import here to avoid circular imports - from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax + from nemo.collections.nlp.modules.common.megatron.fused_softmax import ( + MatchedScaleMaskSoftmax, + ) # disabling fusion for the MatchedScaleMaskSoftmax mod = MatchedScaleMaskSoftmax( - n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale + n.input_in_fp16, + n.input_in_bf16, + n.attn_mask_type, + False, + n.mask_func, + n.softmax_in_fp32, + n.scale, ) return mod -def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: +def wrap_module( + BaseT: Type[nn.Module], DestT: Type[nn.Module] +) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace BaseT module with DestT wrapper. Args: BaseT : module type to replace DestT : destination module type @@ -256,14 +299,15 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): expanded_path = path.split(".") parent_mod = model for sub_path in expanded_path[:-1]: - parent_mod = parent_mod._modules[sub_path] # noqa - parent_mod._modules[expanded_path[-1]] = new_mod # noqa + parent_mod = parent_mod._modules[sub_path] + parent_mod._modules[expanded_path[-1]] = new_mod return model def replace_modules( - model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None + model: nn.Module, + expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None, ) -> nn.Module: """ Top-level function to replace modules in model, specified by class name with a desired replacement. 
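
Editor's note: the swap helpers above are easiest to follow with a tiny model.
A minimal sketch of what replace_for_export does with them (the import paths
assume the repo's scripts/utils package layout; the three-layer model is purely
illustrative):

    import torch.nn as nn

    from scripts.utils.cast_utils import CastToFloat
    from scripts.utils.export_utils import replace_modules, wrap_module

    # A toy network with a norm layer that is numerically fragile in fp16 export.
    model = nn.Sequential(nn.Conv3d(1, 8, 3), nn.InstanceNorm3d(8), nn.ReLU())

    # wrap_module returns a factory that nests each matching module inside
    # CastToFloat, so the norm runs in fp32 even when the surrounding graph
    # is traced under fp16 autocast for ONNX export.
    replace_modules(model, {"InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat)})
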
@@ -308,7 +352,7 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: if apex_available: print("Replacing Apex layers ...") replace_modules(model, default_Apex_replacements) - + if do_cast: print("Adding casts around norms...") cast_replacements = { @@ -319,6 +363,6 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), } replace_modules(model, cast_replacements) - + # This one has to be the last replace_modules(model, script_replacements) diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index 0fc6d6b..77cc7c6 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# +# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, @@ -26,65 +26,69 @@ # limitations under the License. # -from collections import OrderedDict -from typing import List -from copy import copy -import numpy as np import os import pickle -from PIL import Image -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.logger import G_LOGGER as L_ +import threading +from collections import OrderedDict -import random -from scipy import integrate import tensorrt as trt import torch -import traceback - -from io import BytesIO from cuda import cudart -from enum import Enum, auto - -import threading +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx +from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + TrtRunner, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.logger import G_LOGGER as L_ # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # trt.init_libnvinfer_plugins(TRT_LOGGER, '') lock_sm = threading.Lock() + @torch.jit.script def check_m(m): t = torch.isnan(m) return not torch.any(t) + # Map of torch dtype -> numpy dtype trt_to_torch_dtype_dict = { - trt.int32 : torch.int32, + trt.int32: torch.int32, trt.float32: torch.float32, trt.float16: torch.float16, - trt.bfloat16 : torch.float16, - trt.int64 : torch.int64, - trt.int8 : torch.int8, - trt.bool : torch.bool, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, } + def CUASSERT(cuda_ret): err = cuda_ret[0] if err != 0: - raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) if len(cuda_ret) > 1: return cuda_ret[1] return None + class ShapeException(Exception): pass -class Engine(): + +class Engine: def __init__( 
self, engine_path, @@ -93,24 +97,32 @@ def __init__( self.engine = None self.context = None self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph - - def build(self, onnx_path, - profiles=[], fp16=False, bf16=False, tf32=True, - builder_optimization_level=3, - enable_all_tactics=True, - direct_io=False, - timing_cache=None, - update_output_names=None): + self.cuda_graph_instance = None # cuda graph + + def build( + self, + onnx_path, + profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + enable_all_tactics=True, + direct_io=False, + timing_cache=None, + update_output_names=None, + ): L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") config_kwargs = { - 'builder_optimization_level' : builder_optimization_level, - 'direct_io' : direct_io, + "builder_optimization_level": builder_optimization_level, + "direct_io": direct_io, } if not enable_all_tactics: - config_kwargs['tactic_sources'] = [] + config_kwargs["tactic_sources"] = [] - network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + network = network_from_onnx_path( + onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] + ) if update_output_names: L_.info(f"Updating network outputs to {update_output_names}") network = ModifyNetworkOutputs(network, update_output_names) @@ -118,22 +130,22 @@ def build(self, onnx_path, L_.info("Calling engine_from_network...") engine = engine_from_network( - network, - config=CreateConfig( - fp16=fp16, - bf16=bf16, - tf32=tf32, - profiles=profiles, - load_timing_cache=timing_cache, - **config_kwargs - ), - save_timing_cache=timing_cache + network, + config=CreateConfig( + fp16=fp16, + bf16=bf16, + tf32=tf32, + profiles=profiles, + load_timing_cache=timing_cache, + **config_kwargs, + ), + save_timing_cache=timing_cache, ) self.engine = engine - + def save(self): save_engine(self.engine, path=self.engine_path) - + def load(self): L_.info(f"Loading TensorRT engine: {self.engine_path}") self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) @@ -155,18 +167,20 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num + self.cur_profile = profile_num # L_.info(self.input_names) # L_.info(self.output_names) - - def allocate_buffers(self, device): + + def allocate_buffers(self, device): # allocate outputs e = self.engine ctx = self.context - + for i, binding in enumerate(self.output_names): - shape=ctx.get_tensor_shape(binding) - t = torch.empty(list(shape), dtype=self.dtypes[i], device=device).contiguous() + shape = ctx.get_tensor_shape(binding) + t = torch.empty( + list(shape), dtype=self.dtypes[i], device=device + ).contiguous() self.tensors[binding] = t ctx.set_tensor_address(binding, t.data_ptr()) @@ -180,41 +194,39 @@ def check_shape(shape, profile): if s < minlist[i] or s > maxlist[i]: good = False return good - - def set_inputs(self, feed_dict, stream): + + def set_inputs(self, feed_dict, stream): e = self.engine ctx = self.context last_profile = self.cur_profile - + def try_set_inputs(): - for binding, t in feed_dict.items(): + for binding, t in feed_dict.items(): if t is not None: t = t.contiguous() shape = t.shape # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) # if not self.check_shape(shape, mincurmax): - # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> 
{shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") + # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") ctx.set_input_shape(binding, shape) ctx.set_tensor_address(binding, t.data_ptr()) while True: try: try_set_inputs() - break; + break except ShapeException: - next_profile = (self.cur_profile+1) % e.num_optimization_profiles + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles if next_profile == last_profile: raise self.cur_profile = next_profile ctx.set_optimization_profile_async(self.cur_profile, stream) # torch.cuda.synchronize() - - left = ctx.infer_shapes() - assert len(left)==0 + left = ctx.infer_shapes() + assert len(left) == 0 - - def infer(self, stream, use_cuda_graph=False): + def infer(self, stream, use_cuda_graph=False): e = self.engine ctx = self.context if use_cuda_graph: @@ -225,19 +237,26 @@ def infer(self, stream, use_cuda_graph=False): # do inference before CUDA graph capture noerror = self.context.execute_async_v3(stream) if not noerror: - raise ValueError(f"ERROR: inference failed.") + raise ValueError("ERROR: inference failed.") # capture cuda graph - CUASSERT(cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)) + CUASSERT( + cudart.cudaStreamBeginCapture( + stream, + cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal, + ) + ) self.context.execute_async_v3(stream) graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) - self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(graph, 0)) + self.cuda_graph_instance = CUASSERT( + cudart.cudaGraphInstantiate(graph, 0) + ) print("CUDA Graph captured!") else: noerror = self.context.execute_async_v3(stream) CUASSERT(cudart.cudaStreamSynchronize(stream)) if not noerror: - raise ValueError(f"ERROR: inference failed.") - + raise ValueError("ERROR: inference failed.") + return self.tensors @@ -245,10 +264,8 @@ class ExportWrapper(torch.nn.Module): """ An auxiliary class to facilitate ONNX->TRT export of a module """ - def __init__(self, model, - input_names=None, - output_names=None, - precision="fp32"): + + def __init__(self, model, input_names=None, output_names=None, precision="fp32"): super().__init__() self.input_names = input_names self.output_names = output_names @@ -256,13 +273,13 @@ def __init__(self, model, self.model = model self.precision = precision - + def get_export_obj(self): return self.model def sample_profile(self, min_len=None, max_len=None): return None - + def can_handle(self, **args): return True @@ -271,17 +288,19 @@ def wrap(cls, model, **args): wrapper = cls(model, **args) return wrapper - + @torch.jit.script def no_nans(m): t = torch.isnan(m) return not torch.any(t) + class TRTWrapper(torch.nn.Module): """ An auxiliary class to implement running of TRT optimized engines - + """ + def __init__(self, path, exp, use_cuda_graph=False): super().__init__() self.exp_wrapper = None @@ -291,26 +310,29 @@ def __init__(self, path, exp, use_cuda_graph=False): self.jit_model = None self.onnx_runner = None self.path = path - self.use_cuda_graph=use_cuda_graph + self.use_cuda_graph = use_cuda_graph if exp is not None: self.attach(exp) @property def engine_path(self): - return self.path + '.plan' + return self.path + ".plan" + @property def jit_path(self): - return self.path + '.ts' + return self.path + ".ts" + @property def onnx_path(self): - return self.path + '.onnx' + return self.path + ".onnx" + 
@property def profiles_path(self): - return self.path + '.profiles.pkl' + return self.path + ".profiles.pkl" def has_engine(self): return self.engine is not None - + def has_onnx(self): return os.path.exists(self.onnx_path) @@ -322,12 +344,12 @@ def has_profiles(self): def load_engine(self): try: - engine=Engine(self.engine_path) + engine = Engine(self.engine_path) engine.load() engine.activate() self.engine = engine except Exception as e: - print (f"Exception while loading the engine:\n{e}") + print(f"Exception while loading the engine:\n{e}") pass def load_jit(self): @@ -344,18 +366,18 @@ def load_onnx(self, providers=["CUDAExecutionProvider"]): onnx_runner.activate() self.onnx_runner = onnx_runner except Exception: - pass + pass def load_profiles(self): with open(self.profiles_path, "rb") as fp: profiles = pickle.load(fp) self.profiles = profiles return profiles - + def save_profiles(self): with open(self.profiles_path, "wb") as fp: pickle.dump(self.profiles, fp) - + def attach(self, exp): self.exp_wrapper = exp self.input_names = exp.input_names @@ -367,10 +389,10 @@ def can_handle(self, **args): def inputs_to_dict(self, input_example): trt_inputs = {} for i, inp in enumerate(input_example): - input_name=self.engine.input_names[i] + input_name = self.engine.input_names[i] trt_inputs[input_name] = inp return trt_inputs - + def forward(self, **args): try: if self.engine is not None: @@ -386,7 +408,7 @@ def forward(self, **args): ret = self.onnx_runner.infer(args) ret = list(ret.values()) ret = [r.cuda() for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret except Exception as e: @@ -402,9 +424,9 @@ def forward_trt(self, trt_inputs): stream.wait_stream(torch.cuda.current_stream()) ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) ret = list(ret.values()) - #for r in ret: + # for r in ret: # assert no_nans(r), "NaNs in TRT output!" 
- if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret @@ -414,18 +436,22 @@ def forward_trt_runner(self, trt_inputs): ret = list(ret.values()) ret = [r.cuda() for r in ret] check = [check_m(r) for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret - - def build_engine(self, input_profiles=[], - fp16=False, bf16=False, tf32=False, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True): + def build_engine( + self, + input_profiles=[], + fp16=False, + bf16=False, + tf32=False, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): profiles = [] - if len(input_profiles) > 0: + if len(input_profiles) > 0: for input_profile in input_profiles: if isinstance(input_profile, Profile): profiles.append(input_profile) @@ -438,61 +464,71 @@ def build_engine(self, input_profiles=[], self.profiles = profiles self.save_profiles() - engine = Engine(self.path+'.plan') - engine.build(self.onnx_path, profiles, - fp16=fp16, - bf16=bf16, - tf32=tf32, - direct_io=direct_io, - builder_optimization_level=builder_optimization_level, - enable_all_tactics=enable_all_tactics - ) + engine = Engine(self.path + ".plan") + engine.build( + self.onnx_path, + profiles, + fp16=fp16, + bf16=bf16, + tf32=tf32, + direct_io=direct_io, + builder_optimization_level=builder_optimization_level, + enable_all_tactics=enable_all_tactics, + ) engine.activate() self.engine = engine - def jit_export(self, input_example, - verbose=False, ): + def jit_export( + self, + input_example, + verbose=False, + ): self.jit_model = torch.jit.trace( self.exp_wrapper, input_example, ).eval() self.jit_model = torch.jit.freeze(self.jit_model) torch.jit.save(self.jit_model, self.jit_path) - - def onnx_export(self, input_example, - dynamo=False, - onnx_registry=None, - dynamic_shapes=None, - verbose=False, - opset_version=18, - ): + + def onnx_export( + self, + input_example, + dynamo=False, + onnx_registry=None, + dynamic_shapes=None, + verbose=False, + opset_version=18, + ): L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") model = self.exp_wrapper.get_export_obj() from .export_utils import replace_for_export + replace_for_export(model, do_cast=True) if dynamo: - torch.onnx.export(model, - input_example, - self.onnx_path, - dynamo=dynamo, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_shapes=dynamic_shapes + torch.onnx.export( + model, + input_example, + self.onnx_path, + dynamo=dynamo, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_shapes=dynamic_shapes, ) else: - torch.onnx.export(model, - input_example, - self.onnx_path, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_axes=dynamic_shapes + torch.onnx.export( + model, + input_example, + self.onnx_path, + verbose=verbose, + opset_version=opset_version, + do_constant_folding=True, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=dynamic_shapes, ) L_.info("Folding constants...") model_onnx = onnx_from_path(self.onnx_path) @@ -506,28 +542,31 @@ def onnx_export(self, input_example, ) L_.info("Done saving model.") - def build_and_save(self, - input_example, - dynamo=False, - verbose=False, - input_profiles=[], - fp16=False, bf16=False, tf32=True, - 
builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True): - if not self.has_engine(): + def build_and_save( + self, + input_example, + dynamo=False, + verbose=False, + input_profiles=[], + fp16=False, + bf16=False, + tf32=True, + builder_optimization_level=3, + direct_io=False, + enable_all_tactics=True, + ): + if not self.has_engine(): if not self.has_onnx(): self.onnx_export( - input_example, - dynamo=dynamo, - verbose=verbose, - ) + input_example, + dynamo=dynamo, + verbose=verbose, + ) self.build_engine( - fp16=fp16, tf32=tf32, + fp16=fp16, + tf32=tf32, direct_io=direct_io, builder_optimization_level=5, - enable_all_tactics=enable_all_tactics) + enable_all_tactics=enable_all_tactics, + ) self.engine.save() - - - diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 1bf2600..bdbd134 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -498,7 +498,7 @@ def _forward( outputs.reverse() x = x_ - + if with_label: i = 0 for level in self.up_layers_auto: diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index c569615..9c09e97 100755 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -302,11 +302,12 @@ def forward( (input_images,), dynamo=False, verbose=False, - fp16=True, tf32=True, + fp16=True, + tf32=True, builder_optimization_level=5, - enable_all_tactics=True + enable_all_tactics=True, ) - + out, out_auto = self.image_encoder( x=input_images, with_point=point_coords is not None, @@ -319,13 +320,17 @@ def forward( if class_vector is not None: if hasattr(self.class_head, "build_and_save"): self.class_head.build_and_save( - (out_auto, class_vector,), - fp16=True, tf32=True, + ( + out_auto, + class_vector, + ), + fp16=True, + tf32=True, dynamo=False, verbose=False, ) logits, _ = self.class_head(src=out_auto, class_vector=class_vector) - + if point_coords is not None: point_logits = self.point_head( out, point_coords, point_labels, class_vector=prompt_class From 1004dad8d0e4b1a29da16d4b5ffdd17600aedb70 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 31 Jul 2024 17:47:58 -0700 Subject: [PATCH 07/31] Fixing CI issues Signed-off-by: Boris Fomitchev --- scripts/utils/trt_utils.py | 7 ++----- vista3d/modeling/segresnetds.py | 2 +- vista3d/modeling/vista3d.py | 0 3 files changed, 3 insertions(+), 6 deletions(-) mode change 100755 => 100644 vista3d/modeling/vista3d.py diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index 0fc6d6b..8f37607 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -161,7 +161,6 @@ def activate(self, profile_num=0, reuse_device_memory=None): def allocate_buffers(self, device): # allocate outputs - e = self.engine ctx = self.context for i, binding in enumerate(self.output_names): @@ -214,9 +213,7 @@ def try_set_inputs(): - def infer(self, stream, use_cuda_graph=False): - e = self.engine - ctx = self.context + def infer(self, stream, use_cuda_graph=False): if use_cuda_graph: if self.cuda_graph_instance is not None: CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) @@ -413,7 +410,7 @@ def forward_trt_runner(self, trt_inputs): ret = runner.infer(trt_inputs) ret = list(ret.values()) ret = [r.cuda() for r in ret] - check = [check_m(r) for r in ret] + # check = [check_m(r) for r in ret] if len(ret)==1: ret = ret[0] return ret diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 1bf2600..5d43bb8 100644 --- a/vista3d/modeling/segresnetds.py +++ 
b/vista3d/modeling/segresnetds.py @@ -12,7 +12,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Union +from typing import Union, List, Tuple import numpy as np import torch diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py old mode 100755 new mode 100644 From 1942aa5bb59b07ece2b0f20e3655bd16e2282c65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Aug 2024 00:51:58 +0000 Subject: [PATCH 08/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/utils/trt_utils.py | 4 ++-- vista3d/modeling/segresnetds.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index af95821..3632119 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -224,7 +224,7 @@ def try_set_inputs(): left = ctx.infer_shapes() assert len(left) == 0 - + def infer(self, stream, use_cuda_graph=False): if use_cuda_graph: if self.cuda_graph_instance is not None: @@ -433,7 +433,7 @@ def forward_trt_runner(self, trt_inputs): ret = list(ret.values()) ret = [r.cuda() for r in ret] # check = [check_m(r) for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 72bf586..ebd1d9c 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -12,7 +12,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Union, List, Tuple +from typing import List, Tuple, Union import numpy as np import torch From 20dce0e22494c916eda2ef821e8104a56bfb361a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 1 Aug 2024 19:25:07 -0700 Subject: [PATCH 09/31] Improved TRT engine handling, fallback added Signed-off-by: Boris Fomitchev --- scripts/infer.py | 14 +++-- scripts/utils/trt_utils.py | 113 +++++++++++++++++++++++-------------- 2 files changed, 81 insertions(+), 46 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index b55be61..4d3a64e 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -32,7 +32,12 @@ from .sliding_window import point_based_window_inferer, sliding_window_inference from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point -from .utils.trt_utils import ExportWrapper, TRTWrapper + +try: + from .utils.trt_utils import ExportWrapper, TRTWrapper + TRT_AVAILABLE=True +except Exception: + TRT_AVAILABLE=False rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -131,13 +136,14 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.save_transforms = transforms.Compose(save_transforms) self.prev_mask = None self.batch_data = None - if self.trt: + if self.trt and TRT_AVAILABLE: + ts=os.path.getmtime(config_file) en_wrapper = ExportWrapper.wrap( self.model.image_encoder.encoder, input_names=["x"], output_names=["x_out"], ) - self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper) + self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, timestamp=ts) self.model.image_encoder.encoder.load_engine() cls_wrapper = ExportWrapper.wrap( @@ -145,7 +151,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): input_names=["src", "class_vector"], output_names=["masks", "class_embedding"], 
) - self.model.class_head = TRTWrapper("ClassHead", cls_wrapper) + self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, timestamp=ts) self.model.class_head.load_engine() return diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index af95821..811228e 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -26,41 +26,42 @@ # limitations under the License. # +from collections import OrderedDict +from typing import List +from copy import copy +import numpy as np import os import pickle -import threading -from collections import OrderedDict +from PIL import Image +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx +from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx +from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile +from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine +from polygraphy.logger import G_LOGGER as L_ +import random +from scipy import integrate import tensorrt as trt import torch +import traceback + +from io import BytesIO from cuda import cudart -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import ( - CreateConfig, - ModifyNetworkOutputs, - Profile, - TrtRunner, - engine_from_bytes, - engine_from_network, - network_from_onnx_path, - save_engine, -) -from polygraphy.logger import G_LOGGER as L_ +from enum import Enum, auto + +import threading # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # trt.init_libnvinfer_plugins(TRT_LOGGER, '') lock_sm = threading.Lock() - @torch.jit.script def check_m(m): t = torch.isnan(m) return not torch.any(t) - # Map of torch dtype -> numpy dtype trt_to_torch_dtype_dict = { trt.int32: torch.int32, @@ -72,18 +73,28 @@ def check_m(m): trt.bool: torch.bool, } +def get_dynamic_axes(profiles, extra_axes={}): + dynamic_axes=extra_axes + for profile in profiles: + for key in profile: + axes=[] + vals=profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + # print(f"Dynamic axes = {dynamic_axes}") + return dynamic_axes def CUASSERT(cuda_ret): err = cuda_ret[0] if err != 0: - raise RuntimeError( - f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" - ) + raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") if len(cuda_ret) > 1: return cuda_ret[1] return None - class ShapeException(Exception): pass @@ -97,7 +108,7 @@ def __init__( self.engine = None self.context = None self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph + self.cuda_graph_instance = None # cuda graph def build( self, @@ -142,7 +153,7 @@ def build( save_timing_cache=timing_cache, ) self.engine = engine - + def save(self): save_engine(self.engine, path=self.engine_path) @@ -167,11 +178,11 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num + self.cur_profile = profile_num # L_.info(self.input_names) # L_.info(self.output_names) - - def 
allocate_buffers(self, device): + + def allocate_buffers(self, device): # allocate outputs ctx = self.context @@ -298,7 +309,7 @@ class TRTWrapper(torch.nn.Module): """ - def __init__(self, path, exp, use_cuda_graph=False): + def __init__(self, path, exp, use_cuda_graph=False, timestamp=None): super().__init__() self.exp_wrapper = None self.prev_wrapper = None @@ -308,6 +319,16 @@ def __init__(self, path, exp, use_cuda_graph=False): self.onnx_runner = None self.path = path self.use_cuda_graph = use_cuda_graph + + if os.path.exists(self.onnx_path): + ftime=os.path.getmtime(self.onnx_path) + if timestamp is not None and ftime < timestamp: + os.remove(self.onnx_path) + else: + timestamp = ftime + if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: + os.remove(self.engine_path) + if exp is not None: self.attach(exp) @@ -553,17 +574,25 @@ def build_and_save( enable_all_tactics=True, ): if not self.has_engine(): - if not self.has_onnx(): - self.onnx_export( - input_example, - dynamo=dynamo, - verbose=verbose, - ) - self.build_engine( - fp16=fp16, - tf32=tf32, - direct_io=direct_io, - builder_optimization_level=5, - enable_all_tactics=enable_all_tactics, - ) - self.engine.save() + try: + if not self.has_onnx(): + self.onnx_export( + input_example, + dynamo=dynamo, + dynamic_shapes=get_dynamic_axes(input_profiles), + verbose=verbose, + ) + self.build_engine( + input_profiles=input_profiles, + fp16=fp16, tf32=tf32, + direct_io=direct_io, + builder_optimization_level=5, + enable_all_tactics=enable_all_tactics) + self.engine.save() + os.remove(self.onnx_path) + except Exception as e: + raise e + pass + + + From 47006eb5edcc1be3608b994596a8f462ce484a26 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 02:26:21 +0000 Subject: [PATCH 10/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 11 ++++-- scripts/utils/trt_utils.py | 79 +++++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 4d3a64e..07c1e3c 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -35,9 +35,10 @@ try: from .utils.trt_utils import ExportWrapper, TRTWrapper - TRT_AVAILABLE=True + + TRT_AVAILABLE = True except Exception: - TRT_AVAILABLE=False + TRT_AVAILABLE = False rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -137,13 +138,15 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.prev_mask = None self.batch_data = None if self.trt and TRT_AVAILABLE: - ts=os.path.getmtime(config_file) + ts = os.path.getmtime(config_file) en_wrapper = ExportWrapper.wrap( self.model.image_encoder.encoder, input_names=["x"], output_names=["x_out"], ) - self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, timestamp=ts) + self.model.image_encoder.encoder = TRTWrapper( + "Encoder", en_wrapper, timestamp=ts + ) self.model.image_encoder.encoder.load_engine() cls_wrapper = ExportWrapper.wrap( diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index db0e941..f71edea 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -26,42 +26,41 @@ # limitations under the License. 
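The get_dynamic_axes() helper added in this series marks an input axis as dynamic whenever its min and max extents differ across a TensorRT optimization profile, so the ONNX export and the engine build stay in sync. A minimal standalone sketch, using an illustrative profile (name -> [min_shape, opt_shape, max_shape]) rather than one taken from the bundle configs:

    # Condensed copy of the helper above: an axis is dynamic iff its
    # min and max extents differ in the profile.
    def get_dynamic_axes(profiles, extra_axes={}):
        dynamic_axes = extra_axes
        for profile in profiles:
            for key in profile:
                vals = profile[key]  # [min_shape, opt_shape, max_shape]
                axes = [i for i in range(len(vals[0])) if vals[0][i] != vals[2][i]]
                if axes:
                    dynamic_axes[key] = axes
        return dynamic_axes

    # Batch and the three spatial dims vary between min and max; channel does not.
    profile = {"x": [[1, 1, 64, 64, 64], [1, 1, 96, 96, 96], [4, 1, 128, 128, 128]]}
    print(get_dynamic_axes([profile]))  # {'x': [0, 2, 3, 4]}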
# -from collections import OrderedDict -from typing import List -from copy import copy -import numpy as np import os import pickle -from PIL import Image -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.logger import G_LOGGER as L_ +import threading +from collections import OrderedDict -import random -from scipy import integrate import tensorrt as trt import torch -import traceback - -from io import BytesIO from cuda import cudart -from enum import Enum, auto - -import threading +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx import fold_constants, onnx_from_path, save_onnx +from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + TrtRunner, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.logger import G_LOGGER as L_ # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # trt.init_libnvinfer_plugins(TRT_LOGGER, '') lock_sm = threading.Lock() + @torch.jit.script def check_m(m): t = torch.isnan(m) return not torch.any(t) + # Map of torch dtype -> numpy dtype trt_to_torch_dtype_dict = { trt.int32: torch.int32, @@ -73,12 +72,13 @@ def check_m(m): trt.bool: torch.bool, } + def get_dynamic_axes(profiles, extra_axes={}): - dynamic_axes=extra_axes + dynamic_axes = extra_axes for profile in profiles: for key in profile: - axes=[] - vals=profile[key] + axes = [] + vals = profile[key] for i in range(len(vals[0])): if vals[0][i] != vals[2][i]: axes.append(i) @@ -87,14 +87,18 @@ def get_dynamic_axes(profiles, extra_axes={}): # print(f"Dynamic axes = {dynamic_axes}") return dynamic_axes + def CUASSERT(cuda_ret): err = cuda_ret[0] if err != 0: - raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) if len(cuda_ret) > 1: return cuda_ret[1] return None + class ShapeException(Exception): pass @@ -108,7 +112,7 @@ def __init__( self.engine = None self.context = None self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph + self.cuda_graph_instance = None # cuda graph def build( self, @@ -153,7 +157,7 @@ def build( save_timing_cache=timing_cache, ) self.engine = engine - + def save(self): save_engine(self.engine, path=self.engine_path) @@ -178,11 +182,11 @@ def activate(self, profile_num=0, reuse_device_memory=None): self.output_names.append(binding) dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] self.dtypes.append(dtype) - self.cur_profile = profile_num + self.cur_profile = profile_num # L_.info(self.input_names) # L_.info(self.output_names) - - def allocate_buffers(self, device): + + def allocate_buffers(self, device): # allocate outputs ctx = self.context @@ -319,14 +323,18 @@ def __init__(self, path, exp, use_cuda_graph=False, timestamp=None): self.onnx_runner = None self.path = path self.use_cuda_graph = use_cuda_graph - + if os.path.exists(self.onnx_path): - 
ftime=os.path.getmtime(self.onnx_path) + ftime = os.path.getmtime(self.onnx_path) if timestamp is not None and ftime < timestamp: os.remove(self.onnx_path) else: timestamp = ftime - if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: + if ( + timestamp is not None + and os.path.exists(self.engine_path) + and os.path.getmtime(self.engine_path) < timestamp + ): os.remove(self.engine_path) if exp is not None: @@ -584,15 +592,14 @@ def build_and_save( ) self.build_engine( input_profiles=input_profiles, - fp16=fp16, tf32=tf32, + fp16=fp16, + tf32=tf32, direct_io=direct_io, builder_optimization_level=5, - enable_all_tactics=enable_all_tactics) + enable_all_tactics=enable_all_tactics, + ) self.engine.save() os.remove(self.onnx_path) except Exception as e: raise e pass - - - From 8b1408a8a11e5733c35c3f1216e42222cc2fd714 Mon Sep 17 00:00:00 2001 From: Mingxue Gu Date: Fri, 2 Aug 2024 12:04:19 +0000 Subject: [PATCH 11/31] add accuracy benchmark results --- dices.json | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 dices.json diff --git a/dices.json b/dices.json new file mode 100644 index 0000000..3a7a7a6 --- /dev/null +++ b/dices.json @@ -0,0 +1,135 @@ +{ + "liver": 0.9999347925186157, + "kidney": 1.0, + "spleen": 0.9998570084571838, + "pancreas": 0.9997349977493286, + "right kidney": 0.9999557137489319, + "aorta": 1.0, + "inferior vena cava": 0.9998636245727539, + "right adrenal gland": 1.0, + "left adrenal gland": 0.9997097253799438, + "gallbladder": 1.0, + "esophagus": 0.9997258186340332, + "stomach": 0.9999347925186157, + "duodenum": 0.9996980428695679, + "left kidney": 0.9999045729637146, + "bladder": 0.9998233318328857, + "prostate or uterus (deprecated)": 1.0, + "portal vein and splenic vein": 0.999485433101654, + "rectum (deprecated)": 1.0, + "small bowel": 0.9996098875999451, + "lung": 1.0, + "bone": 1.0, + "brain": 1.0, + "lung tumor": 1.0, + "pancreatic tumor": 1.0, + "hepatic vessel": 1.0, + "hepatic tumor": 1.0, + "colon cancer primaries": 1.0, + "left lung upper lobe": 0.9998635053634644, + "left lung lower lobe": 0.9999285340309143, + "right lung upper lobe": 1.0, + "right lung middle lobe": 0.9999430179595947, + "right lung lower lobe": 0.999975323677063, + "vertebrae L5": 0.9999445080757141, + "vertebrae L4": 0.9999210834503174, + "vertebrae L3": 0.9998977184295654, + "vertebrae L2": 0.9999402761459351, + "vertebrae L1": 1.0, + "vertebrae T12": 0.9996854662895203, + "vertebrae T11": 0.9997434616088867, + "vertebrae T10": 0.9998674392700195, + "vertebrae T9": 0.9996097087860107, + "vertebrae T8": 1.0, + "vertebrae T7": 1.0, + "vertebrae T6": 1.0, + "vertebrae T5": 1.0, + "vertebrae T4": 1.0, + "vertebrae T3": 1.0, + "vertebrae T2": 1.0, + "vertebrae T1": 1.0, + "vertebrae C7": 1.0, + "vertebrae C6": 1.0, + "vertebrae C5": 1.0, + "vertebrae C4": 1.0, + "vertebrae C3": 1.0, + "vertebrae C2": 1.0, + "vertebrae C1": 1.0, + "trachea": 1.0, + "left iliac artery": 0.998672604560852, + "right iliac artery": 0.9997827410697937, + "left iliac vena": 0.9996750950813293, + "right iliac vena": 0.9997751712799072, + "colon": 0.999774158000946, + "left rib 1": 1.0, + "left rib 2": 1.0, + "left rib 3": 1.0, + "left rib 4": 1.0, + "left rib 5": 1.0, + "left rib 6": 0.9995150566101074, + "left rib 7": 0.9989900588989258, + "left rib 8": 1.0, + "left rib 9": 0.9997802972793579, + "left rib 10": 0.9982767701148987, + "left rib 11": 1.0, + "left rib 12": 1.0, + "right rib 1": 
1.0, + "right rib 2": 1.0, + "right rib 3": 1.0, + "right rib 4": 1.0, + "right rib 5": 1.0, + "right rib 6": 0.999602198600769, + "right rib 7": 1.0, + "right rib 8": 0.9992419481277466, + "right rib 9": 1.0, + "right rib 10": 0.9998047947883606, + "right rib 11": 1.0, + "right rib 12": 1.0, + "left humerus": 0.9719626307487488, + "right humerus": 0.9873417615890503, + "left scapula": 1.0, + "right scapula": 0.9997193217277527, + "left clavicula": 1.0, + "right clavicula": 1.0, + "left femur": 0.9999800324440002, + "right femur": 0.9998434782028198, + "left hip": 0.9999173879623413, + "right hip": 0.9999226927757263, + "sacrum": 0.9997125267982483, + "left gluteus maximus": 0.9998618960380554, + "right gluteus maximus": 0.9998993277549744, + "left gluteus medius": 0.9997550249099731, + "right gluteus medius": 0.9997763633728027, + "left gluteus minimus": 0.9991177916526794, + "right gluteus minimus": 0.9998393058776855, + "left autochthon": 0.9998349547386169, + "right autochthon": 0.9998183846473694, + "left iliopsoas": 0.9998319149017334, + "right iliopsoas": 0.9998192191123962, + "left atrial appendage": 1.0, + "brachiocephalic trunk": 1.0, + "left brachiocephalic vein": 1.0, + "right brachiocephalic vein": 1.0, + "left common carotid artery": 1.0, + "right common carotid artery": 1.0, + "costal cartilages": 0.9996328353881836, + "heart": 0.9997690320014954, + "left kidney cyst": 1.0, + "right kidney cyst": 0.9991546869277954, + "prostate": 1.0, + "pulmonary vein": 1.0, + "skull": 1.0, + "spinal cord": 0.9995791912078857, + "sternum": 1.0, + "left subclavian artery": 1.0, + "right subclavian artery": 1.0, + "superior vena cava": 0.9977220892906189, + "thyroid gland": 1.0, + "vertebrae S1": 0.9995207786560059, + "bone lesion": 1.0, + "kidney mass (deprecated)": 1.0, + "liver tumor (deprecated)": 1.0, + "vertebrae L6 (deprecated)": 1.0, + "airway": 1.0, + "average": 0.999543797789198 +} \ No newline at end of file From e0c0e7a3dcf50d189ffac319dd6842abfa07af0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:05:09 +0000 Subject: [PATCH 12/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dices.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dices.json b/dices.json index 3a7a7a6..0ab713f 100644 --- a/dices.json +++ b/dices.json @@ -132,4 +132,4 @@ "vertebrae L6 (deprecated)": 1.0, "airway": 1.0, "average": 0.999543797789198 -} \ No newline at end of file +} From 2d4dfc0a316dd027b6c14c90fc2cb4e67aebc397 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 4 Aug 2024 23:27:40 -0700 Subject: [PATCH 13/31] Refactored to use TRTWrapper from MONAI Signed-off-by: Boris Fomitchev --- scripts/infer.py | 27 +- scripts/utils/cast_utils.py | 101 ------ scripts/utils/export_utils.py | 368 --------------------- scripts/utils/trt_utils.py | 598 ---------------------------------- 4 files changed, 13 insertions(+), 1081 deletions(-) delete mode 100644 scripts/utils/cast_utils.py delete mode 100644 scripts/utils/export_utils.py delete mode 100644 scripts/utils/trt_utils.py diff --git a/scripts/infer.py b/scripts/infer.py index 4d3a64e..f25c681 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -34,9 +34,10 @@ from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point try: - from .utils.trt_utils import ExportWrapper, TRTWrapper + from monai.utils import TRTWrapper TRT_AVAILABLE=True -except Exception: 
+except Exception as e: + raise e TRT_AVAILABLE=False rearrange, _ = optional_import("einops", name="rearrange") @@ -138,20 +139,18 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.batch_data = None if self.trt and TRT_AVAILABLE: ts=os.path.getmtime(config_file) - en_wrapper = ExportWrapper.wrap( - self.model.image_encoder.encoder, - input_names=["x"], - output_names=["x_out"], - ) - self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, timestamp=ts) + self.model.image_encoder.encoder = TRTWrapper("Encoder", + self.model.image_encoder.encoder, + input_names=["x"], + output_names=["x_out"], + timestamp=ts) self.model.image_encoder.encoder.load_engine() - cls_wrapper = ExportWrapper.wrap( - self.model.class_head, - input_names=["src", "class_vector"], - output_names=["masks", "class_embedding"], - ) - self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, timestamp=ts) + self.model.class_head = TRTWrapper("ClassHead", + self.model.class_head, + input_names=["src", "class_vector"], + output_names=["masks", "class_embedding"], + timestamp=ts) self.model.class_head.load_engine() return diff --git a/scripts/utils/cast_utils.py b/scripts/utils/cast_utils.py deleted file mode 100644 index 329033e..0000000 --- a/scripts/utils/cast_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
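The refactor above hands ONNX export and engine caching to MONAI's TRTWrapper instead of the local ExportWrapper/TRTWrapper pair. A hedged sketch of the resulting wrapping pattern with a stand-in module; the import location and argument names follow this patch series (a later patch here moves the class to monai.networks.trt_wrapper) and may differ in released MONAI versions:

    import os

    import torch
    from monai.utils import optional_import

    # Guarded import: TRT_AVAILABLE is False when TensorRT support is absent.
    TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")

    encoder = torch.nn.Conv3d(1, 8, 3, padding=1)  # stand-in for the image encoder

    if TRT_AVAILABLE:
        # Engines older than the config file get rebuilt, mirroring the
        # timestamp-based invalidation added earlier in this series.
        ts = os.path.getmtime("configs/infer.yaml")
        encoder = TRTWrapper(
            "Encoder",
            encoder,
            input_names=["x"],
            output_names=["x_out"],
            timestamp=ts,
        )
        encoder.load_engine()  # falls back to PyTorch if no engine exists yet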
- -from contextlib import nullcontext - -import torch - - -def avoid_bfloat16_autocast_context(): - """ - If the current autocast context is bfloat16, - cast it to float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext() - - -def avoid_float16_autocast_context(): - """ - If the current autocast context is float16, cast it to bfloat16 - if available (unless we're in jit) or float32 - """ - - if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return torch.cuda.amp.autocast(dtype=torch.float32) - - if torch.cuda.is_bf16_supported(): - return torch.cuda.amp.autocast(dtype=torch.bfloat16) - else: - return torch.cuda.amp.autocast(dtype=torch.float32) - else: - return nullcontext() - - -def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): - return x.to(dtype=to_dtype) if x.dtype == from_dtype else x - - -def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): - if isinstance(x, torch.Tensor): - return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) - else: - if isinstance(x, dict): - new_dict = {} - for k in x.keys(): - new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) - return new_dict - elif isinstance(x, tuple): - return tuple( - cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x - ) - - -class CastToFloat(torch.nn.Module): - def __init__(self, mod): - super(CastToFloat, self).__init__() - self.mod = mod - - def forward(self, x): - with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) - return ret - - -class CastToFloatAll(torch.nn.Module): - def __init__(self, mod): - super(CastToFloatAll, self).__init__() - self.mod = mod - - def forward(self, *args): - from_dtype = args[0].dtype - with torch.cuda.amp.autocast(enabled=False): - ret = self.mod.forward( - *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32) - ) - return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) diff --git a/scripts/utils/export_utils.py b/scripts/utils/export_utils.py deleted file mode 100644 index b7eecbd..0000000 --- a/scripts/utils/export_utils.py +++ /dev/null @@ -1,368 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
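The cast_utils.py helpers deleted above exist to keep precision-sensitive layers in fp32 during export and mixed-precision inference. A small self-contained sketch of the CastToFloat pattern; the wrapped LayerNorm and tensor shapes are illustrative only:

    import torch
    import torch.nn as nn

    class CastToFloat(nn.Module):
        # Runs the wrapped module in fp32 and casts the result back, so
        # fp16 autocast regions keep normalization numerically stable.
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, x):
            with torch.cuda.amp.autocast(enabled=False):
                return self.mod(x.to(torch.float32)).to(x.dtype)

    norm = CastToFloat(nn.LayerNorm(16))
    out = norm(torch.randn(2, 16))
    print(out.dtype)  # matches the input dtype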
- -from typing import Callable, Dict, Optional, Type - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .cast_utils import CastToFloat - - -class LinearWithBiasSkip(nn.Module): - def __init__(self, weight, bias, skip_bias_add): - super(LinearWithBiasSkip, self).__init__() - self.bias = bias - self.weight = weight - self.skip_bias_add = skip_bias_add - - def forward(self, x): - if self.skip_bias_add: - return F.linear(x, self.weight), self.bias - return F.linear(x, self.weight, self.bias), None - - -def run_ts_and_compare( - ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01 -): - # Verify the model can be read, and is valid - ts_out = ts_model(*ts_input_list, **ts_input_dict) - - all_good = True - for i, out in enumerate(ts_out): - expected = output_example[i] - - if torch.is_tensor(expected): - tout = out.to("cpu") - print(f"Checking output {i}, shape: {expected.shape}:\n") - this_good = True - try: - if not torch.allclose( - tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance - ): - this_good = False - except Exception: # there may ne size mismatch and it may be OK - this_good = False - if not this_good: - print( - f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}" - ) - all_good = False - return all_good - - -def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): - # Verify the model can be read, and is valid - ort_out = sess.run(None, ort_input) - all_good = True - for i, out in enumerate(ort_out): - expected = output_example[i] - - if torch.is_tensor(expected): - tout = torch.from_numpy(out) - print(f"Checking output {i}, shape: {expected.shape}:\n") - this_good = True - try: - if not torch.allclose( - tout, - expected.cpu(), - rtol=check_tolerance, - atol=100 * check_tolerance, - ): - this_good = False - except Exception: # there may ne size mismatch and it may be OK - this_good = False - if not this_good: - print( - f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}" - ) - all_good = False - return all_good - - -apex_available = True - -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm - from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - from apex.transformer.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: - """ - Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedLayerNorm pytorch module to replace - Returns: - Equivalent LayerNorm module - """ - - p = next(n.parameters()) - if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): - shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine - elif isinstance(n, FastLayerNorm): - shape, eps, affine = n.weight.shape, n.epsilon, True - else: - return None - - mod = nn.LayerNorm( - shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype - ) - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. 
- Args: - n: the FusedLayerNorm pytorch module to replace - Returns: - Equivalent LayerNorm module - """ - if not isinstance(n, RowParallelLinear): - raise ValueError( - "This function can only change the RowParallelLinear module." - ) - - dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) - - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's ColumnParallelLinear or RowParallelLinear with nn.Linear - Args: - n: the nn.Module pytorch module to replace - Returns: - Equivalent Linear module - """ - if not ( - isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear) - ): - raise ValueError( - "This function can only change the ColumnParallelLinear or RowParallelLinear module." - ) - - dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) - - n_state = n.state_dict() - mod.load_state_dict(n_state) - return mod - - def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces Apex's FusedScaleMaskSoftmax with nn.LayerNorm. This is required for ONNX export. - Args: - n: the FusedScaleMaskSoftmax module to replace - Returns: - Equivalent LayerNorm module - """ - if not isinstance(n, FusedScaleMaskSoftmax): - raise ValueError( - "This function can only change the FusedScaleMaskSoftmax module." - ) - - # disable the fusion only - mod = FusedScaleMaskSoftmax( - n.input_in_fp16, - n.input_in_bf16, - n.attn_mask_type, - False, - n.mask_func, - n.softmax_in_fp32, - n.scale, - ) - - return mod - - default_Apex_replacements = { - "FusedLayerNorm": replace_FusedLayerNorm, - "MixedFusedLayerNorm": replace_FusedLayerNorm, - "FastLayerNorm": replace_FusedLayerNorm, - "ESM1bLayerNorm": replace_FusedLayerNorm, - "RowParallelLinear": replace_ParallelLinear, - "ColumnParallelLinear": replace_ParallelLinear, - "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax, - } - -except Exception: - default_Apex_replacements = {} - apex_available = False - - -def simple_replace( - BaseT: Type[nn.Module], DestT: Type[nn.Module] -) -> Callable[[nn.Module], Optional[nn.Module]]: - """ - Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied. - Args: - BaseT : module type to replace - DestT : destination module type - Returns: - swap function to replace BaseT module with DestT - """ - - def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: - if not isinstance(mod, BaseT): - return None - args = [getattr(mod, name, None) for name in mod.__constants__] - out = DestT(*args) - return out - - return expansion_fn - - -def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: - """ - Replaces MatchedScaleMaskSoftmax with exportable softmax layer - Args: - n: module to replace - Returns: - exportable module - """ - # including the import here to avoid circular imports - from nemo.collections.nlp.modules.common.megatron.fused_softmax import ( - MatchedScaleMaskSoftmax, - ) - - # disabling fusion for the MatchedScaleMaskSoftmax - mod = MatchedScaleMaskSoftmax( - n.input_in_fp16, - n.input_in_bf16, - n.attn_mask_type, - False, - n.mask_func, - n.softmax_in_fp32, - n.scale, - ) - return mod - - -def wrap_module( - BaseT: Type[nn.Module], DestT: Type[nn.Module] -) -> Callable[[nn.Module], Optional[nn.Module]]: - """ - Generic function generator to replace BaseT module with DestT wrapper. 
- Args: - BaseT : module type to replace - DestT : destination module type - Returns: - swap function to replace BaseT module with DestT - """ - - def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: - out = DestT(mod) - return out - - return expansion_fn - - -def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]): - """ - This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows - for swapping nested modules through arbitrary levels if children - - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - - """ - for path, new_mod in mapping.items(): - expanded_path = path.split(".") - parent_mod = model - for sub_path in expanded_path[:-1]: - parent_mod = parent_mod._modules[sub_path] - parent_mod._modules[expanded_path[-1]] = new_mod - - return model - - -def replace_modules( - model: nn.Module, - expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None, -) -> nn.Module: - """ - Top-level function to replace modules in model, specified by class name with a desired replacement. - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - Args: - model : top level module - expansions : replacement dictionary: module class name -> replacement function generator - Returns: - model, possibly modified in-place - """ - mapping: Dict[str, nn.Module] = {} - for name, m in model.named_modules(): - m_type = type(m).__name__ - if m_type in expansions: - # print (f"Found {m_type} in expansions ...") - swapped = expansions[m_type](m) - if swapped: - mapping[name] = swapped - - print(f"Swapped {len(mapping)} modules") - swap_modules(model, mapping) - return model - - -def script_module(m: nn.Module): - return torch.jit.script(m) - - -script_replacements = {} - - -def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module: - """ - Top-level function to replace default set of modules in model - NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. - Args: - model : top level module - replace_1D_2D : include 1D -> 2D replacements - Returns: - model, possibly modified in-place - """ - if apex_available: - print("Replacing Apex layers ...") - replace_modules(model, default_Apex_replacements) - - if do_cast: - print("Adding casts around norms...") - cast_replacements = { - "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), - "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), - "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), - "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat), - "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat), - } - replace_modules(model, cast_replacements) - - # This one has to be the last - replace_modules(model, script_replacements) diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py deleted file mode 100644 index db0e941..0000000 --- a/scripts/utils/trt_utils.py +++ /dev/null @@ -1,598 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. 
Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -# -# Copyright 2022 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from collections import OrderedDict -from typing import List -from copy import copy -import numpy as np -import os -import pickle -from PIL import Image -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.onnx import onnx_from_path, fold_constants, save_onnx -from polygraphy.backend.onnxrt import OnnxrtRunner, session_from_onnx -from polygraphy.backend.trt import TrtRunner, CreateConfig, ModifyNetworkOutputs, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.logger import G_LOGGER as L_ - -import random -from scipy import integrate -import tensorrt as trt -import torch -import traceback - -from io import BytesIO -from cuda import cudart -from enum import Enum, auto - -import threading - -# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) -# trt.init_libnvinfer_plugins(TRT_LOGGER, '') - -lock_sm = threading.Lock() - -@torch.jit.script -def check_m(m): - t = torch.isnan(m) - return not torch.any(t) - -# Map of torch dtype -> numpy dtype -trt_to_torch_dtype_dict = { - trt.int32: torch.int32, - trt.float32: torch.float32, - trt.float16: torch.float16, - trt.bfloat16: torch.float16, - trt.int64: torch.int64, - trt.int8: torch.int8, - trt.bool: torch.bool, -} - -def get_dynamic_axes(profiles, extra_axes={}): - dynamic_axes=extra_axes - for profile in profiles: - for key in profile: - axes=[] - vals=profile[key] - for i in range(len(vals[0])): - if vals[0][i] != vals[2][i]: - axes.append(i) - if len(axes) > 0: - dynamic_axes[key] = axes - # print(f"Dynamic axes = {dynamic_axes}") - return dynamic_axes - -def CUASSERT(cuda_ret): - err = cuda_ret[0] - if err != 0: - raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t") - if len(cuda_ret) > 1: - return cuda_ret[1] - return None - -class ShapeException(Exception): - pass - - -class Engine: - def __init__( - self, - engine_path, - ): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self.cuda_graph_instance = None # cuda graph - - def build( - self, - onnx_path, - profiles=[], - fp16=False, - bf16=False, - tf32=True, - builder_optimization_level=3, - enable_all_tactics=True, - direct_io=False, - timing_cache=None, - update_output_names=None, - ): - L_.info(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - config_kwargs = { - "builder_optimization_level": builder_optimization_level, - "direct_io": direct_io, - } - if not enable_all_tactics: 
- config_kwargs["tactic_sources"] = [] - - network = network_from_onnx_path( - onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] - ) - if update_output_names: - L_.info(f"Updating network outputs to {update_output_names}") - network = ModifyNetworkOutputs(network, update_output_names) - # with L.verbosity(0): - L_.info("Calling engine_from_network...") - - engine = engine_from_network( - network, - config=CreateConfig( - fp16=fp16, - bf16=bf16, - tf32=tf32, - profiles=profiles, - load_timing_cache=timing_cache, - **config_kwargs, - ), - save_timing_cache=timing_cache, - ) - self.engine = engine - - def save(self): - save_engine(self.engine, path=self.engine_path) - - def load(self): - L_.info(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self, profile_num=0, reuse_device_memory=None): - if reuse_device_memory: - self.context = self.engine.create_execution_context_without_device_memory() - self.context.device_memory = reuse_device_memory - else: - self.context = self.engine.create_execution_context() - self.input_names = [] - self.output_names = [] - self.dtypes = [] - for idx in range(self.engine.num_io_tensors): - binding = self.engine[idx] - if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: - self.input_names.append(binding) - elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: - self.output_names.append(binding) - dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(binding)] - self.dtypes.append(dtype) - self.cur_profile = profile_num - # L_.info(self.input_names) - # L_.info(self.output_names) - - def allocate_buffers(self, device): - # allocate outputs - ctx = self.context - - for i, binding in enumerate(self.output_names): - shape = ctx.get_tensor_shape(binding) - t = torch.empty( - list(shape), dtype=self.dtypes[i], device=device - ).contiguous() - self.tensors[binding] = t - ctx.set_tensor_address(binding, t.data_ptr()) - - @staticmethod - def check_shape(shape, profile): - shape = list(shape) - minlist = profile[0] - maxlist = profile[2] - good = True - for i, s in enumerate(shape): - if s < minlist[i] or s > maxlist[i]: - good = False - return good - - def set_inputs(self, feed_dict, stream): - e = self.engine - ctx = self.context - last_profile = self.cur_profile - - def try_set_inputs(): - for binding, t in feed_dict.items(): - if t is not None: - t = t.contiguous() - shape = t.shape - # mincurmax = list(e.get_profile_shape(self.cur_profile, binding)) - # if not self.check_shape(shape, mincurmax): - # raise ShapeException(f"Input shape to be set is outside the bounds: {binding} -> {shape}, profile is {mincurmax}, trying another profile: {self.cur_profile}") - ctx.set_input_shape(binding, shape) - ctx.set_tensor_address(binding, t.data_ptr()) - - while True: - try: - try_set_inputs() - break - except ShapeException: - next_profile = (self.cur_profile + 1) % e.num_optimization_profiles - if next_profile == last_profile: - raise - self.cur_profile = next_profile - ctx.set_optimization_profile_async(self.cur_profile, stream) - # torch.cuda.synchronize() - - left = ctx.infer_shapes() - assert len(left) == 0 - - def infer(self, stream, use_cuda_graph=False): - if use_cuda_graph: - if self.cuda_graph_instance is not None: - CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) - CUASSERT(cudart.cudaStreamSynchronize(stream)) - else: - # do inference before CUDA graph capture - noerror = self.context.execute_async_v3(stream) - if not 
noerror: - raise ValueError("ERROR: inference failed.") - # capture cuda graph - CUASSERT( - cudart.cudaStreamBeginCapture( - stream, - cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal, - ) - ) - self.context.execute_async_v3(stream) - graph = CUASSERT(cudart.cudaStreamEndCapture(stream)) - self.cuda_graph_instance = CUASSERT( - cudart.cudaGraphInstantiate(graph, 0) - ) - print("CUDA Graph captured!") - else: - noerror = self.context.execute_async_v3(stream) - CUASSERT(cudart.cudaStreamSynchronize(stream)) - if not noerror: - raise ValueError("ERROR: inference failed.") - - return self.tensors - - -class ExportWrapper(torch.nn.Module): - """ - An auxiliary class to facilitate ONNX->TRT export of a module - """ - - def __init__(self, model, input_names=None, output_names=None, precision="fp32"): - super().__init__() - self.input_names = input_names - self.output_names = output_names - self.dynamic_shapes = None - - self.model = model - self.precision = precision - - def get_export_obj(self): - return self.model - - def sample_profile(self, min_len=None, max_len=None): - return None - - def can_handle(self, **args): - return True - - @classmethod - def wrap(cls, model, **args): - wrapper = cls(model, **args) - return wrapper - - -@torch.jit.script -def no_nans(m): - t = torch.isnan(m) - return not torch.any(t) - - -class TRTWrapper(torch.nn.Module): - """ - An auxiliary class to implement running of TRT optimized engines - - """ - - def __init__(self, path, exp, use_cuda_graph=False, timestamp=None): - super().__init__() - self.exp_wrapper = None - self.prev_wrapper = None - self.profiles = None - self.engine = None - self.jit_model = None - self.onnx_runner = None - self.path = path - self.use_cuda_graph = use_cuda_graph - - if os.path.exists(self.onnx_path): - ftime=os.path.getmtime(self.onnx_path) - if timestamp is not None and ftime < timestamp: - os.remove(self.onnx_path) - else: - timestamp = ftime - if timestamp is not None and os.path.exists(self.engine_path) and os.path.getmtime(self.engine_path) < timestamp: - os.remove(self.engine_path) - - if exp is not None: - self.attach(exp) - - @property - def engine_path(self): - return self.path + ".plan" - - @property - def jit_path(self): - return self.path + ".ts" - - @property - def onnx_path(self): - return self.path + ".onnx" - - @property - def profiles_path(self): - return self.path + ".profiles.pkl" - - def has_engine(self): - return self.engine is not None - - def has_onnx(self): - return os.path.exists(self.onnx_path) - - def has_jit(self): - return os.path.exists(self.jit_path) - - def has_profiles(self): - return os.path.exists(self.profiles_path) - - def load_engine(self): - try: - engine = Engine(self.engine_path) - engine.load() - engine.activate() - self.engine = engine - except Exception as e: - print(f"Exception while loading the engine:\n{e}") - pass - - def load_jit(self): - try: - self.jit_model = torch.jit.load(self.jit_path) - except Exception: - pass - - def load_onnx(self, providers=["CUDAExecutionProvider"]): - try: - onnx_runner = OnnxrtRunner( - session_from_onnx(self.onnx_path, providers=providers) - ) - onnx_runner.activate() - self.onnx_runner = onnx_runner - except Exception: - pass - - def load_profiles(self): - with open(self.profiles_path, "rb") as fp: - profiles = pickle.load(fp) - self.profiles = profiles - return profiles - - def save_profiles(self): - with open(self.profiles_path, "wb") as fp: - pickle.dump(self.profiles, fp) - - def attach(self, exp): - self.exp_wrapper = exp - 
self.input_names = exp.input_names - self.output_names = exp.output_names - - def can_handle(self, **args): - return self.exp_wrapper.can_handle(**args) - - def inputs_to_dict(self, input_example): - trt_inputs = {} - for i, inp in enumerate(input_example): - input_name = self.engine.input_names[i] - trt_inputs[input_name] = inp - return trt_inputs - - def forward(self, **args): - try: - if self.engine is not None: - if self.can_handle(**args): - # print(f"Running {self.engine_path}...") - # forward_trt is not thread safe as we do not use per-thread execution contexts - with lock_sm: - return self.forward_trt(args) - elif self.jit_model is not None: - return self.jit_model.forward(**args) - elif self.onnx_runner is not None: - print(f"Running {self.onnx_path}...") - ret = self.onnx_runner.infer(args) - ret = list(ret.values()) - ret = [r.cuda() for r in ret] - if len(ret) == 1: - ret = ret[0] - return ret - except Exception as e: - print(f"Exception: {e}\nFalling back to Pytorch ...") - - return self.exp_wrapper.get_export_obj().forward(**args) - - def forward_trt(self, trt_inputs): - stream = torch.cuda.Stream(device=torch.cuda.current_device()) - self.engine.set_inputs(trt_inputs, stream.cuda_stream) - self.engine.allocate_buffers(torch.device("cuda")) - # Need this to synchronize with Torch stream - stream.wait_stream(torch.cuda.current_stream()) - ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) - ret = list(ret.values()) - # for r in ret: - # assert no_nans(r), "NaNs in TRT output!" - if len(ret) == 1: - ret = ret[0] - return ret - - def forward_trt_runner(self, trt_inputs): - with TrtRunner(self.engine) as runner: - ret = runner.infer(trt_inputs) - ret = list(ret.values()) - ret = [r.cuda() for r in ret] - # check = [check_m(r) for r in ret] - if len(ret) == 1: - ret = ret[0] - return ret - - def build_engine( - self, - input_profiles=[], - fp16=False, - bf16=False, - tf32=False, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True, - ): - profiles = [] - if len(input_profiles) > 0: - for input_profile in input_profiles: - if isinstance(input_profile, Profile): - profiles.append(input_profile) - else: - p = Profile() - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - profiles.append(p) - self.profiles = profiles - self.save_profiles() - - engine = Engine(self.path + ".plan") - engine.build( - self.onnx_path, - profiles, - fp16=fp16, - bf16=bf16, - tf32=tf32, - direct_io=direct_io, - builder_optimization_level=builder_optimization_level, - enable_all_tactics=enable_all_tactics, - ) - engine.activate() - self.engine = engine - - def jit_export( - self, - input_example, - verbose=False, - ): - self.jit_model = torch.jit.trace( - self.exp_wrapper, - input_example, - ).eval() - self.jit_model = torch.jit.freeze(self.jit_model) - torch.jit.save(self.jit_model, self.jit_path) - - def onnx_export( - self, - input_example, - dynamo=False, - onnx_registry=None, - dynamic_shapes=None, - verbose=False, - opset_version=18, - ): - L_.info(f"Exporting to ONNX, dynamic shapes: {dynamic_shapes}") - model = self.exp_wrapper.get_export_obj() - from .export_utils import replace_for_export - - replace_for_export(model, do_cast=True) - - if dynamo: - torch.onnx.export( - model, - input_example, - self.onnx_path, - dynamo=dynamo, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - 
dynamic_shapes=dynamic_shapes, - ) - else: - torch.onnx.export( - model, - input_example, - self.onnx_path, - verbose=verbose, - opset_version=opset_version, - do_constant_folding=True, - input_names=self.input_names, - output_names=self.output_names, - dynamic_axes=dynamic_shapes, - ) - L_.info("Folding constants...") - model_onnx = onnx_from_path(self.onnx_path) - fold_constants(model_onnx, allow_onnxruntime_shape_inference=False) - L_.info("Done folding constants.") - - L_.info("Saving model...") - save_onnx( - model_onnx, - self.onnx_path, - ) - L_.info("Done saving model.") - - def build_and_save( - self, - input_example, - dynamo=False, - verbose=False, - input_profiles=[], - fp16=False, - bf16=False, - tf32=True, - builder_optimization_level=3, - direct_io=False, - enable_all_tactics=True, - ): - if not self.has_engine(): - try: - if not self.has_onnx(): - self.onnx_export( - input_example, - dynamo=dynamo, - dynamic_shapes=get_dynamic_axes(input_profiles), - verbose=verbose, - ) - self.build_engine( - input_profiles=input_profiles, - fp16=fp16, tf32=tf32, - direct_io=direct_io, - builder_optimization_level=5, - enable_all_tactics=enable_all_tactics) - self.engine.save() - os.remove(self.onnx_path) - except Exception as e: - raise e - pass - - - From 92cda1b7da41551b2c177456079b5d77323f082d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 06:41:12 +0000 Subject: [PATCH 14/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index e542d46..6043b98 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -35,9 +35,10 @@ try: from monai.utils import TRTWrapper - TRT_AVAILABLE=True -except Exception as e: - TRT_AVAILABLE=False + + TRT_AVAILABLE = True +except Exception: + TRT_AVAILABLE = False rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -137,19 +138,23 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.prev_mask = None self.batch_data = None if self.trt and TRT_AVAILABLE: - ts=os.path.getmtime(config_file) - self.model.image_encoder.encoder = TRTWrapper("Encoder", - self.model.image_encoder.encoder, - input_names=["x"], - output_names=["x_out"], - timestamp=ts) + ts = os.path.getmtime(config_file) + self.model.image_encoder.encoder = TRTWrapper( + "Encoder", + self.model.image_encoder.encoder, + input_names=["x"], + output_names=["x_out"], + timestamp=ts, + ) self.model.image_encoder.encoder.load_engine() - self.model.class_head = TRTWrapper("ClassHead", - self.model.class_head, - input_names=["src", "class_vector"], - output_names=["masks", "class_embedding"], - timestamp=ts) + self.model.class_head = TRTWrapper( + "ClassHead", + self.model.class_head, + input_names=["src", "class_vector"], + output_names=["masks", "class_embedding"], + timestamp=ts, + ) self.model.class_head.load_engine() return From 15e37e32fb449c173dc04aa8d901b80b6ec80c64 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 12:44:04 -0700 Subject: [PATCH 15/31] Using optional import Signed-off-by: Boris Fomitchev --- scripts/infer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index e542d46..32c003d 100644 --- a/scripts/infer.py +++ 
b/scripts/infer.py @@ -33,11 +33,7 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point -try: - from monai.utils import TRTWrapper - TRT_AVAILABLE=True -except Exception as e: - TRT_AVAILABLE=False +from monai.utils import TRT_AVAILABLE, TRTWrapper rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) From 8a01bb513d96e8b87b4f86c7568a314e00997b67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:45:15 +0000 Subject: [PATCH 16/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 83f566f..799b583 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -25,7 +25,7 @@ from monai.bundle import ConfigParser from monai.bundle.scripts import _pop_args, _update_args from monai.data import decollate_batch, list_data_collate, partition_dataset -from monai.utils import optional_import +from monai.utils import TRT_AVAILABLE, TRTWrapper, optional_import from vista3d import vista_model_registry @@ -33,8 +33,6 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point -from monai.utils import TRT_AVAILABLE, TRTWrapper - rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) IGNORE_PROMPT = set( From 9cadc9711fbf88241e1a8e8416758af226a54f35 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 12:56:17 -0700 Subject: [PATCH 17/31] Using optional import, take 2 Signed-off-by: Boris Fomitchev --- scripts/infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/infer.py b/scripts/infer.py index 799b583..142a970 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -25,7 +25,8 @@ from monai.bundle import ConfigParser from monai.bundle.scripts import _pop_args, _update_args from monai.data import decollate_batch, list_data_collate, partition_dataset -from monai.utils import TRT_AVAILABLE, TRTWrapper, optional_import +from monai.utils import optional_import +TRTWrapper, TRT_AVAILABLE = optional_import('monai.utils', name='TRTWrapper') from vista3d import vista_model_registry From 6a3d2ac7a8767adc9dc9ea26074ef49732d610c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:56:41 +0000 Subject: [PATCH 18/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/infer.py b/scripts/infer.py index 142a970..0879338 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -26,7 +26,8 @@ from monai.bundle.scripts import _pop_args, _update_args from monai.data import decollate_batch, list_data_collate, partition_dataset from monai.utils import optional_import -TRTWrapper, TRT_AVAILABLE = optional_import('monai.utils', name='TRTWrapper') + +TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper") from vista3d import vista_model_registry From 939abc28f829e9500394a22b32a2ef0865c03c11 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 5 Aug 2024 16:17:24 -0700 Subject: [PATCH 19/31] precision_constraints=obey 
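With fp16 enabled, the TensorRT builder is normally free to pick per-layer
precisions for speed; "obey" makes it honor explicitly requested layer
precisions instead of treating them as hints, trading some tactic freedom
for numerical accuracy. A hedged sketch of the polygraphy builder call these
options feed into (the ONNX path is illustrative, not from this repo):

    from polygraphy.backend.trt import (
        CreateConfig,
        engine_from_network,
        network_from_onnx_path,
    )

    # Build an engine that must keep the requested layer precisions.
    engine = engine_from_network(
        network_from_onnx_path("model.onnx"),
        config=CreateConfig(
            fp16=True,
            tf32=True,
            builder_optimization_level=5,
            precision_constraints="obey",
        ),
    )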
Signed-off-by: Boris Fomitchev --- vista3d/modeling/vista3d.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index 9c09e97..1ce0364 100644 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -300,12 +300,10 @@ def forward( if hasattr(self.image_encoder.encoder, "build_and_save"): self.image_encoder.encoder.build_and_save( (input_images,), - dynamo=False, - verbose=False, fp16=True, tf32=True, builder_optimization_level=5, - enable_all_tactics=True, + precision_constraints="obey", ) out, out_auto = self.image_encoder( @@ -326,9 +324,10 @@ def forward( ), fp16=True, tf32=True, - dynamo=False, - verbose=False, + builder_optimization_level=5, + precision_constraints="obey", ) + logits, _ = self.class_head(src=out_auto, class_vector=class_vector) if point_coords is not None: From 4f1d21c4c8a728b0750414ec1b8f515ee11001ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:18:17 +0000 Subject: [PATCH 20/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- vista3d/modeling/vista3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index 1ce0364..b9720b7 100644 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -327,7 +327,7 @@ def forward( builder_optimization_level=5, precision_constraints="obey", ) - + logits, _ = self.class_head(src=out_auto, class_vector=class_vector) if point_coords is not None: From b14474d36aec323e143d4690a2d8f6e38eef0a52 Mon Sep 17 00:00:00 2001 From: Mingxue Gu Date: Tue, 6 Aug 2024 16:15:32 +0000 Subject: [PATCH 21/31] update accuracy benchmark --- dices.json | 112 ++++++++++++++++++++++++++--------------------------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/dices.json b/dices.json index 0ab713f..35b15a5 100644 --- a/dices.json +++ b/dices.json @@ -1,23 +1,23 @@ { - "liver": 0.9999347925186157, + "liver": 0.9999467134475708, "kidney": 1.0, - "spleen": 0.9998570084571838, - "pancreas": 0.9997349977493286, - "right kidney": 0.9999557137489319, - "aorta": 1.0, - "inferior vena cava": 0.9998636245727539, + "spleen": 0.9998987317085266, + "pancreas": 0.9998106360435486, + "right kidney": 0.9997254610061646, + "aorta": 0.9999536275863647, + "inferior vena cava": 0.9997954964637756, "right adrenal gland": 1.0, - "left adrenal gland": 0.9997097253799438, + "left adrenal gland": 0.9971064925193787, "gallbladder": 1.0, "esophagus": 0.9997258186340332, - "stomach": 0.9999347925186157, - "duodenum": 0.9996980428695679, - "left kidney": 0.9999045729637146, + "stomach": 0.9999147653579712, + "duodenum": 0.9995471835136414, + "left kidney": 0.9997535347938538, "bladder": 0.9998233318328857, "prostate or uterus (deprecated)": 1.0, - "portal vein and splenic vein": 0.999485433101654, + "portal vein and splenic vein": 0.9996570348739624, "rectum (deprecated)": 1.0, - "small bowel": 0.9996098875999451, + "small bowel": 0.9995405673980713, "lung": 1.0, "bone": 1.0, "brain": 1.0, @@ -26,21 +26,21 @@ "hepatic vessel": 1.0, "hepatic tumor": 1.0, "colon cancer primaries": 1.0, - "left lung upper lobe": 0.9998635053634644, - "left lung lower lobe": 0.9999285340309143, + "left lung upper lobe": 0.9999317526817322, + "left lung lower lobe": 0.9999247789382935, "right lung upper lobe": 1.0, - "right lung middle lobe": 0.9999430179595947, - 
"right lung lower lobe": 0.999975323677063, - "vertebrae L5": 0.9999445080757141, + "right lung middle lobe": 0.9999620318412781, + "right lung lower lobe": 0.9999691843986511, + "vertebrae L5": 0.9999167323112488, "vertebrae L4": 0.9999210834503174, - "vertebrae L3": 0.9998977184295654, - "vertebrae L2": 0.9999402761459351, - "vertebrae L1": 1.0, - "vertebrae T12": 0.9996854662895203, + "vertebrae L3": 1.0, + "vertebrae L2": 0.9997909665107727, + "vertebrae L1": 0.9998704195022583, + "vertebrae T12": 0.999764084815979, "vertebrae T11": 0.9997434616088867, "vertebrae T10": 0.9998674392700195, - "vertebrae T9": 0.9996097087860107, - "vertebrae T8": 1.0, + "vertebrae T9": 0.9997072815895081, + "vertebrae T8": 0.9992929697036743, "vertebrae T7": 1.0, "vertebrae T6": 1.0, "vertebrae T5": 1.0, @@ -58,19 +58,19 @@ "trachea": 1.0, "left iliac artery": 0.998672604560852, "right iliac artery": 0.9997827410697937, - "left iliac vena": 0.9996750950813293, + "left iliac vena": 0.9996752142906189, "right iliac vena": 0.9997751712799072, - "colon": 0.999774158000946, + "colon": 0.9997839331626892, "left rib 1": 1.0, "left rib 2": 1.0, "left rib 3": 1.0, "left rib 4": 1.0, "left rib 5": 1.0, - "left rib 6": 0.9995150566101074, - "left rib 7": 0.9989900588989258, + "left rib 6": 0.9985436797142029, + "left rib 7": 0.9997116327285767, "left rib 8": 1.0, - "left rib 9": 0.9997802972793579, - "left rib 10": 0.9982767701148987, + "left rib 9": 0.9997071027755737, + "left rib 10": 0.9987931251525879, "left rib 11": 1.0, "left rib 12": 1.0, "right rib 1": 1.0, @@ -78,58 +78,58 @@ "right rib 3": 1.0, "right rib 4": 1.0, "right rib 5": 1.0, - "right rib 6": 0.999602198600769, - "right rib 7": 1.0, - "right rib 8": 0.9992419481277466, + "right rib 6": 0.9992054104804993, + "right rib 7": 0.999552845954895, + "right rib 8": 0.9996969103813171, "right rib 9": 1.0, - "right rib 10": 0.9998047947883606, + "right rib 10": 0.9995119571685791, "right rib 11": 1.0, "right rib 12": 1.0, "left humerus": 0.9719626307487488, "right humerus": 0.9873417615890503, "left scapula": 1.0, - "right scapula": 0.9997193217277527, + "right scapula": 1.0, "left clavicula": 1.0, "right clavicula": 1.0, - "left femur": 0.9999800324440002, - "right femur": 0.9998434782028198, - "left hip": 0.9999173879623413, + "left femur": 0.999920129776001, + "right femur": 0.9998330473899841, + "left hip": 0.9999256730079651, "right hip": 0.9999226927757263, - "sacrum": 0.9997125267982483, - "left gluteus maximus": 0.9998618960380554, - "right gluteus maximus": 0.9998993277549744, - "left gluteus medius": 0.9997550249099731, - "right gluteus medius": 0.9997763633728027, - "left gluteus minimus": 0.9991177916526794, - "right gluteus minimus": 0.9998393058776855, - "left autochthon": 0.9998349547386169, - "right autochthon": 0.9998183846473694, - "left iliopsoas": 0.9998319149017334, - "right iliopsoas": 0.9998192191123962, + "sacrum": 0.9997796416282654, + "left gluteus maximus": 0.9998824000358582, + "right gluteus maximus": 0.9998437166213989, + "left gluteus medius": 0.9997230172157288, + "right gluteus medius": 0.9997458457946777, + "left gluteus minimus": 0.9993826150894165, + "right gluteus minimus": 0.9997991919517517, + "left autochthon": 0.999840259552002, + "right autochthon": 0.9998072981834412, + "left iliopsoas": 0.9998109340667725, + "right iliopsoas": 0.9998148679733276, "left atrial appendage": 1.0, "brachiocephalic trunk": 1.0, "left brachiocephalic vein": 1.0, "right brachiocephalic vein": 1.0, "left common carotid artery": 1.0, "right 
From 1b7d13a26ea94532ec6669a1e896204421c17dfa Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 6 Aug 2024 16:34:23 +0000
Subject: [PATCH 22/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 dices.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dices.json b/dices.json
index 35b15a5..6febab3 100644
--- a/dices.json
+++ b/dices.json
@@ -132,4 +132,4 @@
     "vertebrae L6 (deprecated)": 1.0,
     "airway": 1.0,
     "average": 0.9995385372277462
-}
\ No newline at end of file
+}

From cd3ee1e373533a1d9be75623f4ce634a29c3d46a Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Tue, 6 Aug 2024 13:17:22 -0700
Subject: [PATCH 23/31] Fixed ruff

Signed-off-by: Boris Fomitchev
---
 scripts/infer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/scripts/infer.py b/scripts/infer.py
index 0879338..11f9e3e 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -26,15 +26,14 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import decollate_batch, list_data_collate, partition_dataset
 from monai.utils import optional_import
-
-TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")
-
 from vista3d import vista_model_registry

 from .sliding_window import point_based_window_inferer, sliding_window_inference
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

+TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")
+
 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
 IGNORE_PROMPT = set(

From 9b6cd9e0e7f40909b448911d49a2881660cfbd62 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 6 Aug 2024 20:17:43 +0000
Subject: [PATCH 24/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 scripts/infer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/infer.py b/scripts/infer.py
index 11f9e3e..0ef9daa 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -26,6 +26,7 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import decollate_batch, list_data_collate, partition_dataset
 from monai.utils import optional_import
+
 from vista3d import vista_model_registry

 from .sliding_window import point_based_window_inferer, sliding_window_inference
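Patches 23 and 24 only move imports around for ruff and pre-commit, but they preserve the pattern that keeps TensorRT strictly optional: optional_import returns the requested attribute together with a found flag instead of raising, so TRT support can be gated at runtime. A minimal sketch of that guard, with maybe_wrap as a hypothetical helper (the wrapper's exact signature changes over the following patches):

    from monai.utils import optional_import

    # optional_import returns (attribute_or_stub, found) and never raises here,
    # so this line is safe on machines without TensorRT installed.
    TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")


    def maybe_wrap(module, use_trt, **trt_kwargs):
        # Hypothetical helper: wrap only when TRT was requested and is importable,
        # otherwise hand back the plain PyTorch module unchanged.
        if use_trt and TRT_AVAILABLE:
            return TRTWrapper(module, **trt_kwargs)
        return module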
From fd7f6aff069a24dab61fa6fa991d1bb2b30e2363 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Thu, 8 Aug 2024 19:24:34 -0700
Subject: [PATCH 25/31] Adjusted for TRTWrapper move

Signed-off-by: Boris Fomitchev
---
 scripts/infer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/infer.py b/scripts/infer.py
index 0ef9daa..c92222b 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -33,7 +33,7 @@
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

-TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")
+TRTWrapper, TRT_AVAILABLE = optional_import("monai.networks.trt_wrapper", name="TRTWrapper")

 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

From 90b5b477567b0946795611ee4018916a040b3fb8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 9 Aug 2024 02:24:56 +0000
Subject: [PATCH 26/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 scripts/infer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/infer.py b/scripts/infer.py
index c92222b..11f420d 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -33,7 +33,9 @@
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

-TRTWrapper, TRT_AVAILABLE = optional_import("monai.networks.trt_wrapper", name="TRTWrapper")
+TRTWrapper, TRT_AVAILABLE = optional_import(
+    "monai.networks.trt_wrapper", name="TRTWrapper"
+)

 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
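Patch 25 tracks the wrapper's relocation from monai.utils to monai.networks.trt_wrapper, and patch 26 is just the pre-commit line wrap of the same statement. Because both import paths occur in this series and neither is guaranteed by any given MONAI release, a tolerant import along these lines is one way to cope (a sketch, not part of the patches themselves):

    from monai.utils import optional_import

    # Try the newer location first, then fall back to the pre-move one.
    TRTWrapper, TRT_AVAILABLE = optional_import(
        "monai.networks.trt_wrapper", name="TRTWrapper"
    )
    if not TRT_AVAILABLE:
        TRTWrapper, TRT_AVAILABLE = optional_import("monai.utils", name="TRTWrapper")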
"build_and_save"): - self.image_encoder.encoder.build_and_save( - (input_images,), - fp16=True, - tf32=True, - builder_optimization_level=5, - precision_constraints="obey", - ) - out, out_auto = self.image_encoder( - x=input_images, + input_images, with_point=point_coords is not None, with_label=class_vector is not None, ) @@ -316,19 +306,7 @@ def forward( # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: - if hasattr(self.class_head, "build_and_save"): - self.class_head.build_and_save( - ( - out_auto, - class_vector, - ), - fp16=True, - tf32=True, - builder_optimization_level=5, - precision_constraints="obey", - ) - - logits, _ = self.class_head(src=out_auto, class_vector=class_vector) + logits, _ = self.class_head(out_auto, class_vector=class_vector) if point_coords is not None: point_logits = self.point_head( From b24abb308e731c703b55f9e4c9f788ca580d25e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 10 Aug 2024 06:56:16 +0000 Subject: [PATCH 28/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 3fd663a..66da415 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -143,7 +143,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): precision="fp16", build_args={ "builder_optimization_level": 5, - "precision_constraints":"obey" + "precision_constraints": "obey", }, timestamp=ts, ) @@ -153,7 +153,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): precision="fp16", build_args={ "builder_optimization_level": 5, - "precision_constraints":"obey" + "precision_constraints": "obey", }, timestamp=ts, ) From 60533388ff3b94a2354c6feec1278bd2b15c87b3 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sat, 10 Aug 2024 23:09:05 -0700 Subject: [PATCH 29/31] Adjusted TRTWrapper args Signed-off-by: Boris Fomitchev --- scripts/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 66da415..509c4c5 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -138,8 +138,8 @@ def __init__(self, config_file="./configs/infer.yaml", **override): bundle_root = parser.get_parsed_content("bundle_root") ts = os.path.getmtime(config_file) self.model.image_encoder.encoder = TRTWrapper( - f"{bundle_root}/image_encoder", self.model.image_encoder.encoder, + f"{bundle_root}/image_encoder", precision="fp16", build_args={ "builder_optimization_level": 5, @@ -148,8 +148,8 @@ def __init__(self, config_file="./configs/infer.yaml", **override): timestamp=ts, ) self.model.class_head = TRTWrapper( - f"{bundle_root}/class_head", self.model.class_head, + f"{bundle_root}/class_head", precision="fp16", build_args={ "builder_optimization_level": 5, From 2ec59eb63c393cb7792cc71c8d3adfd78838317f Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Sun, 18 Aug 2024 13:51:58 -0700 Subject: [PATCH 30/31] Adjusted for TRT wrapper refactoring Signed-off-by: Boris Fomitchev --- scripts/infer.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 509c4c5..856a7c0 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -33,8 +33,8 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point 
From 2ec59eb63c393cb7792cc71c8d3adfd78838317f Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Sun, 18 Aug 2024 13:51:58 -0700
Subject: [PATCH 30/31] Adjusted for TRT wrapper refactoring

Signed-off-by: Boris Fomitchev
---
 scripts/infer.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/scripts/infer.py b/scripts/infer.py
index 509c4c5..856a7c0 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -33,8 +33,8 @@
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

-TRTWrapper, TRT_AVAILABLE = optional_import(
-    "monai.networks.trt_wrapper", name="TRTWrapper"
+trt_wrap, TRT_AVAILABLE = optional_import(
+    "monai.networks", name="trt_wrap"
 )

 rearrange, _ = optional_import("einops", name="rearrange")
@@ -137,25 +137,24 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         if self.trt and TRT_AVAILABLE:
             bundle_root = parser.get_parsed_content("bundle_root")
             ts = os.path.getmtime(config_file)
-            self.model.image_encoder.encoder = TRTWrapper(
-                self.model.image_encoder.encoder,
-                f"{bundle_root}/image_encoder",
-                precision="fp16",
-                build_args={
+            trt_args = {
+                "precision": "fp16",
+                "build_args": {
                     "builder_optimization_level": 5,
                     "precision_constraints": "obey",
                 },
-                timestamp=ts,
+                "timestamp": ts
+            }
+
+            trt_wrap(
+                self.model.image_encoder.encoder,
+                f"{bundle_root}/image_encoder",
+                args=trt_args,
             )
-            self.model.class_head = TRTWrapper(
+            trt_wrap(
                 self.model.class_head,
                 f"{bundle_root}/class_head",
-                precision="fp16",
-                build_args={
-                    "builder_optimization_level": 5,
-                    "precision_constraints": "obey",
-                },
-                timestamp=ts,
+                args=trt_args,
             )
         return

From 2047da00004fe06de17ddb9a7e0e832178441ed8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 18 Aug 2024 20:52:29 +0000
Subject: [PATCH 31/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 scripts/infer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/scripts/infer.py b/scripts/infer.py
index 856a7c0..313a7a2 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -33,9 +33,7 @@
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

-trt_wrap, TRT_AVAILABLE = optional_import(
-    "monai.networks", name="trt_wrap"
-)
+trt_wrap, TRT_AVAILABLE = optional_import("monai.networks", name="trt_wrap")

 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -143,7 +141,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
                 "builder_optimization_level": 5,
                 "precision_constraints": "obey",
             },
-            "timestamp": ts
+            "timestamp": ts,
         }

         trt_wrap(