Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
6e3201f
Initial profiling/export
borisfom Jul 16, 2024
a1da5e2
stash
borisfom Jul 19, 2024
d854bc5
Working TRT wrappers for encoder and class head
borisfom Jul 24, 2024
a82ce56
Merge remote-tracking branch 'origin/vista3d' into vista3d-export
borisfom Jul 24, 2024
818a548
Cleaned up, working TRT wrapping
borisfom Jul 24, 2024
3e4a84b
Cleanup
borisfom Jul 31, 2024
a45deed
Merge remote-tracking branch 'origin/vista3d' into vista3d-export
borisfom Jul 31, 2024
91606c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2024
1004dad
Fixing CI issues
borisfom Aug 1, 2024
e141954
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 1, 2024
1942aa5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2024
20dce0e
Improved TRT engine handling, fallback added
borisfom Aug 2, 2024
6844725
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 2, 2024
47006eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
8b1408a
add accuracy benchmark results
mingxueg-nv Aug 2, 2024
4650e1b
Merge branch 'vista3d-export' of https://github.com/borisfom/VISTA in…
mingxueg-nv Aug 2, 2024
e0c0e7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
2d4dfc0
Refactored to use TRTWrapper from MONAI
borisfom Aug 5, 2024
20f1b19
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 5, 2024
92cda1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
15e37e3
Using optional import
borisfom Aug 5, 2024
20cf733
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 5, 2024
8a01bb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
9cadc97
Using optional import, take 2
borisfom Aug 5, 2024
6a3d2ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
939abc2
precision_constraints=obey
borisfom Aug 5, 2024
6b711f9
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 5, 2024
4f1d21c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
b14474d
update accuracy benchmark
mingxueg-nv Aug 6, 2024
1b7d13a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
cd3ee1e
Fixed ruff
borisfom Aug 6, 2024
9b6cd9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2024
fd7f6af
Adjusted for TRWrapper move
borisfom Aug 9, 2024
90b5b47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2024
c0641e5
Adjusted for TRTWrapper API change
borisfom Aug 10, 2024
b668233
Merge branch 'vista3d-export' of github.com:borisfom/VISTA into vista…
borisfom Aug 10, 2024
b24abb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2024
6053338
Adjusted TRTWrapper args
borisfom Aug 11, 2024
2ec59eb
Adjusted for TRT wrapper refactoring
borisfom Aug 18, 2024
2047da0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://githu

## License

The codebase is under the Apache 2.0 License. The model weights are under a special NVIDIA license.
The codebase is under the Apache 2.0 License. The model weights are under a special NVIDIA license.

## Reference

Expand Down
1 change: 1 addition & 0 deletions configs/infer.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
trt: true
amp: true
input_channels: 1
patch_size: [128, 128, 128]
Expand Down
2 changes: 1 addition & 1 deletion data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds
to one dataset.

##### 2. Add label_dict.json and label_mapping.json
Add new class indexes to `label_dict.json` and the local-to-global mapping to `label_mapping.json`.
Add new class indexes to `label_dict.json` and the local-to-global mapping to `label_mapping.json`.

## SuperVoxel Generation
1. Download the segment anything repo and download the ViT-H weights
Expand Down
135 changes: 135 additions & 0 deletions dices.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
{
"liver": 0.9999467134475708,
"kidney": 1.0,
"spleen": 0.9998987317085266,
"pancreas": 0.9998106360435486,
"right kidney": 0.9997254610061646,
"aorta": 0.9999536275863647,
"inferior vena cava": 0.9997954964637756,
"right adrenal gland": 1.0,
"left adrenal gland": 0.9971064925193787,
"gallbladder": 1.0,
"esophagus": 0.9997258186340332,
"stomach": 0.9999147653579712,
"duodenum": 0.9995471835136414,
"left kidney": 0.9997535347938538,
"bladder": 0.9998233318328857,
"prostate or uterus (deprecated)": 1.0,
"portal vein and splenic vein": 0.9996570348739624,
"rectum (deprecated)": 1.0,
"small bowel": 0.9995405673980713,
"lung": 1.0,
"bone": 1.0,
"brain": 1.0,
"lung tumor": 1.0,
"pancreatic tumor": 1.0,
"hepatic vessel": 1.0,
"hepatic tumor": 1.0,
"colon cancer primaries": 1.0,
"left lung upper lobe": 0.9999317526817322,
"left lung lower lobe": 0.9999247789382935,
"right lung upper lobe": 1.0,
"right lung middle lobe": 0.9999620318412781,
"right lung lower lobe": 0.9999691843986511,
"vertebrae L5": 0.9999167323112488,
"vertebrae L4": 0.9999210834503174,
"vertebrae L3": 1.0,
"vertebrae L2": 0.9997909665107727,
"vertebrae L1": 0.9998704195022583,
"vertebrae T12": 0.999764084815979,
"vertebrae T11": 0.9997434616088867,
"vertebrae T10": 0.9998674392700195,
"vertebrae T9": 0.9997072815895081,
"vertebrae T8": 0.9992929697036743,
"vertebrae T7": 1.0,
"vertebrae T6": 1.0,
"vertebrae T5": 1.0,
"vertebrae T4": 1.0,
"vertebrae T3": 1.0,
"vertebrae T2": 1.0,
"vertebrae T1": 1.0,
"vertebrae C7": 1.0,
"vertebrae C6": 1.0,
"vertebrae C5": 1.0,
"vertebrae C4": 1.0,
"vertebrae C3": 1.0,
"vertebrae C2": 1.0,
"vertebrae C1": 1.0,
"trachea": 1.0,
"left iliac artery": 0.998672604560852,
"right iliac artery": 0.9997827410697937,
"left iliac vena": 0.9996752142906189,
"right iliac vena": 0.9997751712799072,
"colon": 0.9997839331626892,
"left rib 1": 1.0,
"left rib 2": 1.0,
"left rib 3": 1.0,
"left rib 4": 1.0,
"left rib 5": 1.0,
"left rib 6": 0.9985436797142029,
"left rib 7": 0.9997116327285767,
"left rib 8": 1.0,
"left rib 9": 0.9997071027755737,
"left rib 10": 0.9987931251525879,
"left rib 11": 1.0,
"left rib 12": 1.0,
"right rib 1": 1.0,
"right rib 2": 1.0,
"right rib 3": 1.0,
"right rib 4": 1.0,
"right rib 5": 1.0,
"right rib 6": 0.9992054104804993,
"right rib 7": 0.999552845954895,
"right rib 8": 0.9996969103813171,
"right rib 9": 1.0,
"right rib 10": 0.9995119571685791,
"right rib 11": 1.0,
"right rib 12": 1.0,
"left humerus": 0.9719626307487488,
"right humerus": 0.9873417615890503,
"left scapula": 1.0,
"right scapula": 1.0,
"left clavicula": 1.0,
"right clavicula": 1.0,
"left femur": 0.999920129776001,
"right femur": 0.9998330473899841,
"left hip": 0.9999256730079651,
"right hip": 0.9999226927757263,
"sacrum": 0.9997796416282654,
"left gluteus maximus": 0.9998824000358582,
"right gluteus maximus": 0.9998437166213989,
"left gluteus medius": 0.9997230172157288,
"right gluteus medius": 0.9997458457946777,
"left gluteus minimus": 0.9993826150894165,
"right gluteus minimus": 0.9997991919517517,
"left autochthon": 0.999840259552002,
"right autochthon": 0.9998072981834412,
"left iliopsoas": 0.9998109340667725,
"right iliopsoas": 0.9998148679733276,
"left atrial appendage": 1.0,
"brachiocephalic trunk": 1.0,
"left brachiocephalic vein": 1.0,
"right brachiocephalic vein": 1.0,
"left common carotid artery": 1.0,
"right common carotid artery": 1.0,
"costal cartilages": 0.9993331432342529,
"heart": 0.9998570084571838,
"left kidney cyst": 1.0,
"right kidney cyst": 0.9997888803482056,
"prostate": 1.0,
"pulmonary vein": 1.0,
"skull": 1.0,
"spinal cord": 0.9996580481529236,
"sternum": 1.0,
"left subclavian artery": 1.0,
"right subclavian artery": 1.0,
"superior vena cava": 1.0,
"thyroid gland": 1.0,
"vertebrae S1": 0.9998401999473572,
"bone lesion": 1.0,
"kidney mass (deprecated)": 1.0,
"liver tumor (deprecated)": 1.0,
"vertebrae L6 (deprecated)": 1.0,
"airway": 1.0,
"average": 0.9995385372277462
}
10 changes: 7 additions & 3 deletions scripts/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,20 @@ def on_button_click(event, ax=ax):
print("-- segmenting ---")
self.generate_mask()
print("-- done ---")
print("-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---")
print("-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---")
print(
"-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---"
)
print(
"-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---"
)
print("-- Note: CTRL + Right Click will be adding negative points. ---")
print(
"-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---"
)
print(
"-- Note: Click points not matching class prompts will also cause confusion. ---"
)

self.update_slice(ax)
# self.point_start = len(self.clicked_points)

Expand Down
34 changes: 34 additions & 0 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import sys
import time
from functools import partial

import monai
Expand All @@ -32,6 +33,8 @@
from .train import CONFIG
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

trt_wrap, TRT_AVAILABLE = optional_import("monai.networks", name="trt_wrap")

rearrange, _ = optional_import("einops", name="rearrange")
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
IGNORE_PROMPT = set(
Expand Down Expand Up @@ -73,6 +76,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
parser.update(pairs=_args)

self.amp = parser.get_parsed_content("amp")
self.trt = parser.get_parsed_content("trt")
input_channels = parser.get_parsed_content("input_channels")
patch_size = parser.get_parsed_content("patch_size")
self.patch_size = patch_size
Expand Down Expand Up @@ -128,6 +132,28 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
self.save_transforms = transforms.Compose(save_transforms)
self.prev_mask = None
self.batch_data = None
if self.trt and TRT_AVAILABLE:
bundle_root = parser.get_parsed_content("bundle_root")
ts = os.path.getmtime(config_file)
trt_args = {
"precision": "fp16",
"build_args": {
"builder_optimization_level": 5,
"precision_constraints": "obey",
},
"timestamp": ts,
}

trt_wrap(
self.model.image_encoder.encoder,
f"{bundle_root}/image_encoder",
args=trt_args,
)
trt_wrap(
self.model.class_head,
f"{bundle_root}/class_head",
args=trt_args,
)
return

def clear_cache(self):
Expand Down Expand Up @@ -161,6 +187,7 @@ def infer(
used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
time and avoid repeated inference. This is by default disabled.
"""
time00 = time.time()
self.model.eval()
if not isinstance(image_file, dict):
image_file = {"image": image_file}
Expand Down Expand Up @@ -255,12 +282,15 @@ def infer(
finished = False
if finished:
break
print(f"Infer Time: {time.time() - time00}")

if not finished:
raise RuntimeError("Infer not finished due to OOM.")
return batch_data[0]["pred"]

@torch.no_grad()
def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
time00 = time.time()
self.model.eval()
device = f"cuda:{rank}"
if not isinstance(image_file, dict):
Expand Down Expand Up @@ -302,6 +332,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
finished = False
if finished:
break
print(f"InferEverything Time: {time.time() - time00}")

if not finished:
raise RuntimeError("Infer not finished due to OOM.")

Expand All @@ -324,5 +356,7 @@ def batch_infer_everything(self, datalist=str, basedir=str):


if __name__ == "__main__":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
fire, _ = optional_import("fire")
fire.Fire(InferClass)
12 changes: 7 additions & 5 deletions vista3d/modeling/segresnetds.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

from collections.abc import Callable
from typing import Union
from typing import List, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -473,7 +473,7 @@ def _forward(
f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}"
)

x_down = self.encoder(x)
x_down = self.encoder(x=x)

x_down.reverse()
x = x_down.pop(0)
Expand All @@ -483,8 +483,9 @@ def _forward(

outputs: list[torch.Tensor] = []
outputs_auto: list[torch.Tensor] = []
x_ = x.clone()

if with_point:
x_ = x.clone()
i = 0
for level in self.up_layers:
x = level["upsample"](x)
Expand All @@ -496,7 +497,8 @@ def _forward(
i = i + 1

outputs.reverse()
x = x_
x = x_

if with_label:
i = 0
for level in self.up_layers_auto:
Expand All @@ -522,7 +524,7 @@ def _forward(

def forward(
self, x: torch.Tensor, with_point=True, with_label=True, **kwargs
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return self._forward(x, with_point, with_label)

def set_auto_grad(self, auto_freeze=False, point_freeze=False):
Expand Down
3 changes: 2 additions & 1 deletion vista3d/modeling/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ def forward(
# force releasing memories that set to None
torch.cuda.empty_cache()
if class_vector is not None:
logits, _ = self.class_head(out_auto, class_vector)
logits, _ = self.class_head(out_auto, class_vector=class_vector)

if point_coords is not None:
point_logits = self.point_head(
out, point_coords, point_labels, class_vector=prompt_class
Expand Down