5 changes: 5 additions & 0 deletions vista3d/README.md
@@ -13,6 +13,11 @@ limitations under the License.

# MONAI **V**ersatile **I**maging **S**egmen**T**ation and **A**nnotation
[[`Paper`](https://arxiv.org/pdf/2406.05285)] [[`Demo`](https://build.nvidia.com/nvidia/vista-3d)] [[`Checkpoint`](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)]

## News!
[03/12/2025] We provide VISTA3D as a baseline for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" challenge ([link](https://www.codabench.org/competitions/5263/)). Simplified code based on MONAI 1.4 is provided [here](./cvpr_workshop/).

[02/26/2025] The VISTA3D paper has been accepted by **CVPR 2025**!
## Overview

**VISTA3D** is a foundation model trained systematically on 11,454 volumes encompassing 127 types of human anatomical structures and various lesions. It provides accurate out-of-the-box segmentation that matches state-of-the-art supervised models trained on each individual dataset. The model also achieves state-of-the-art zero-shot interactive segmentation in 3D, representing a promising step toward developing a versatile medical image foundation model.
24 changes: 24 additions & 0 deletions vista3d/cvpr_workshop/Dockerfile
@@ -0,0 +1,24 @@
# Use an appropriate base image with GPU support
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
RUN apt-get update && apt-get install -y \
python3 python3-pip && \
rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /workspace

# Copy inference script and requirements
COPY infer_cvpr.py /workspace/infer.py
COPY train_cvpr.py /workspace/train.py
COPY update_ckpt.py /workspace/update_ckpt.py
COPY Dockerfile /workspace/Dockerfile
COPY requirements.txt /workspace/
COPY model_final.pth /workspace
# Install Python dependencies
RUN pip3 install -r requirements.txt

# Copy the prediction script
COPY predict.sh /workspace/predict.sh
RUN chmod +x /workspace/predict.sh

# Set default command
CMD ["/bin/bash"]
34 changes: 34 additions & 0 deletions vista3d/cvpr_workshop/README.md
@@ -0,0 +1,34 @@
<!--
Copyright (c) MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Overview
This folder is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" challenge ([link](https://www.codabench.org/competitions/5263/)). It
is based on MONAI 1.4. Many of the functions in the main VISTA3D repository have been moved into MONAI 1.4, so this simplified folder uses those components directly from MONAI.

The code is heavily simplified for training interactive segmentation models across different modalities; the sophisticated transforms and training recipes used for VISTA3D have been removed.

# Setup
```
pip install -r requirements.txt
```

# Training
Download the VISTA3D pretrained checkpoint, or train from scratch. Generate a JSON list that contains your training data (a hypothetical format is sketched after the command below), then launch training:
```
torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py
```
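
A minimal sketch of such a data list is shown below. The `training` section and the `image`/`label` keys follow the common MONAI datalist layout and are assumptions here, so adjust them to whatever `train_cvpr.py` actually expects:
```
{
    "training": [
        {"image": "imagesTr/case_0001.nii.gz", "label": "labelsTr/case_0001.nii.gz"},
        {"image": "imagesTr/case_0002.nii.gz", "label": "labelsTr/case_0002.nii.gz"}
    ]
}
```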

# Inference
We provide a Dockerfile to satisfy the challenge submission format. For more details, refer to the [challenge website](https://www.codabench.org/competitions/5263/).
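
A hypothetical build-and-run sequence is sketched below; the image name `vista3d_cvpr` and the local `inputs/`/`outputs/` folders are placeholders, while the container paths match the Dockerfile and `predict.sh` in this folder:
```
docker build -t vista3d_cvpr .
docker run --gpus all --rm \
    -v $(pwd)/inputs:/workspace/inputs \
    -v $(pwd)/outputs:/workspace/outputs \
    vista3d_cvpr /workspace/predict.sh
```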


147 changes: 147 additions & 0 deletions vista3d/cvpr_workshop/infer_cvpr.py
@@ -0,0 +1,147 @@
import argparse
import glob
import os

import monai
import monai.transforms
import nibabel as nib
import numpy as np
import torch
from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.inferers import SlidingWindowInfererAdapt
from monai.networks.nets.vista3d import vista3d132
from monai.utils import optional_import

tqdm, _ = optional_import("tqdm", name="tqdm")

def convert_clicks(alldata):
    """Convert a list of {'fg': [...], 'bg': [...]} click dicts into padded point arrays."""
    data = alldata
    B = len(data)  # number of objects
    indexes = np.arange(1, B + 1).tolist()  # 1-based object labels
    # Determine the maximum number of points across all objects
    max_N = max(len(obj['fg']) + len(obj['bg']) for obj in data)

# Initialize padded arrays
point_coords = np.zeros((B, max_N, 3), dtype=int)
point_labels = np.full((B, max_N), -1, dtype=int)

for i, obj in enumerate(data):
points = []
labels = []

# Add foreground points
for fg_point in obj['fg']:
points.append(fg_point)
labels.append(1)

# Add background points
for bg_point in obj['bg']:
points.append(bg_point)
labels.append(0)

# Fill in the arrays
point_coords[i, :len(points)] = points
point_labels[i, :len(labels)] = labels

return point_coords, point_labels, indexes
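

# Hypothetical usage, assuming the 'clicks' layout read from the challenge npz files in the main loop below:
#   clicks = [{'fg': [[10, 20, 30]], 'bg': [[5, 5, 5]]}, {'fg': [[40, 41, 42]], 'bg': []}]
#   coords, labels, idx = convert_clicks(clicks)
#   coords has shape (2, 2, 3); labels -> [[1, 0], [1, -1]]; idx -> [1, 2]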


if __name__ == '__main__':
    # set to True to save NIfTI files for visualization
    save_data = False
    point_inferer = True  # use the point-based window inferer instead of plain sliding window
    roi_size = [128, 128, 128]
parser = argparse.ArgumentParser()
parser.add_argument("--test_img_path", type=str, default='./tests')
parser.add_argument("--save_path", type=str, default='./outputs/')
parser.add_argument("--model", type=str, default='checkpoints/model_final.pth')
args = parser.parse_args()
os.makedirs(args.save_path,exist_ok=True)
# load model
checkpoint_path = args.model
model = vista3d132(in_channels=1)
pretrained_ckpt = torch.load(checkpoint_path, map_location='cuda')
model.load_state_dict(pretrained_ckpt, strict=True)

# load data
test_cases = glob.glob(os.path.join(args.test_img_path, "*.npz"))
for img_path in test_cases:
case_name = os.path.basename(img_path)
print(case_name)
img = np.load(img_path, allow_pickle=True)
img_array = img['imgs']
spacing = img['spacing']
original_shape = img_array.shape
affine = np.diag(spacing.tolist() + [1]) # 4x4 affine matrix
if save_data:
# Create a NIfTI image
nifti_img = nib.Nifti1Image(img_array, affine)
# Save the NIfTI file
nib.save(nifti_img, img_path.replace('.npz','.nii.gz'))
nifti_img = nib.Nifti1Image(img['gts'], affine)
# Save the NIfTI file
nib.save(nifti_img, img_path.replace('.npz','gts.nii.gz'))
clicks = img.get('clicks', [{'fg':[[418, 138, 136]], 'bg':[]}])
point_coords, point_labels, indexes = convert_clicks(clicks)
# preprocess
img_array = torch.from_numpy(img_array)
img_array = img_array.unsqueeze(0)
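        # percentile-based intensity normalization: clip to the 1st/99th percentiles and rescale to [0, 1]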
img_array = monai.transforms.ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0, b_max=1, clip=True)(img_array)
img_array = img_array.unsqueeze(0) # add channel dim
device = 'cuda'
# slidingwindow
with torch.no_grad():
if not point_inferer:
                model.NINF_VALUE = 0  # set to 0 when the plain sliding-window inferer is used
                # directly using the sliding-window inferer is not optimal
val_outputs = SlidingWindowInfererAdapt(
roi_size=roi_size, sw_batch_size=1, with_coord=True, padding_mode="replicate"
)(
inputs=img_array.to(device),
transpose=True,
network=model.to(device),
point_coords=torch.from_numpy(point_coords).to(device),
point_labels=torch.from_numpy(point_labels).to(device)
)[0] > 0
final_outputs = torch.zeros_like(val_outputs[0], dtype=torch.float32)
for i, v in enumerate(val_outputs):
final_outputs += indexes[i] * v
else:
# point based
final_outputs = torch.zeros_like(img_array[0,0], dtype=torch.float32)
for i, v in enumerate(indexes):
val_outputs = point_based_window_inferer(
inputs=img_array.to(device),
roi_size=roi_size,
transpose=True,
with_coord=True,
predictor=model.to(device),
mode="gaussian",
sw_device=device,
device=device,
center_only=True, # only crop the center
point_coords=torch.from_numpy(point_coords[[i]]).to(device),
point_labels=torch.from_numpy(point_labels[[i]]).to(device)
)[0] > 0
final_outputs[val_outputs[0]] = v
final_outputs = torch.nan_to_num(final_outputs)
# save data
if save_data:
# Create a NIfTI image
nifti_img = nib.Nifti1Image(final_outputs.to(torch.float32).data.cpu().numpy(), affine)
# Save the NIfTI file
nib.save(nifti_img, os.path.join(args.save_path, case_name.replace('.npz','.nii.gz')))
np.savez_compressed(os.path.join(args.save_path, case_name), segs=final_outputs.to(torch.float32).data.cpu().numpy())










4 changes: 4 additions & 0 deletions vista3d/cvpr_workshop/predict.sh
@@ -0,0 +1,4 @@
#!/bin/bash

# Run inference script with input/output folder paths
python3 infer.py --test_img_path /workspace/inputs/ --save_path /workspace/outputs/ --model /workspace/model_final.pth
13 changes: 13 additions & 0 deletions vista3d/cvpr_workshop/requirements.txt
@@ -0,0 +1,13 @@
tensorboard
matplotlib
monai
torchvision
nibabel
torch
connected-components-3d
pandas
numpy
scipy
cupy-cuda12x
cucim
tqdm