diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a7b501f..69e6bfe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: end-of-file-fixer
       - id: trailing-whitespace
@@ -26,18 +26,18 @@ repos:
         args: ['--autofix', '--no-sort-keys', '--indent=4']
       - id: end-of-file-fixer
       - id: mixed-line-ending
-  - repo: https://github.com/psf/black
-    rev: "24.10.0"
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: "25.12.0"
     hooks:
       - id: black
       - id: black-jupyter
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 7.0.0
     hooks:
       - id: isort
         args: ["--profile", "black"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.6
+    rev: v0.14.10
     hooks:
       - id: ruff
         args: ['--fix']
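As a quick sanity check after bumping the hook revisions above, the updated hooks can be exercised locally. This is a minimal sketch using the standard `pre-commit` CLI; no repo-specific flags are assumed:

```bash
# Install the git hook scripts once per clone.
pre-commit install

# Run every configured hook (pre-commit-hooks, black, isort, ruff, ...) against
# the whole tree, surfacing any changes the new hook versions would apply.
pre-commit run --all-files
```

`pre-commit autoupdate` can also bump the `rev` fields to the latest tags automatically, which is the usual way mirrors such as `black-pre-commit-mirror` are kept current.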
diff --git a/vista3d/README.md b/vista3d/README.md
index 070573f..e34a99b 100644
--- a/vista3d/README.md
+++ b/vista3d/README.md
@@ -96,9 +96,9 @@ mv model-zoo/models/vista3d vista3dbundle & rm -rf model-zoo
 cd vista3dbundle
 mkdir models
 # minor model weights naming conversion due to monai version change
-wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt 
+wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
 ```
-MONAI bundle accepts multiple json config files and input arguments. The latter configs/arguments will overide the previous configs/arguments if they have overlapping keys. 
+MONAI bundle accepts multiple JSON config files and input arguments. Later configs/arguments override earlier ones when they have overlapping keys.
 ```python
 # Automatically segment everything
 python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz'}"
@@ -108,7 +108,7 @@ python -m monai.bundle run --config_file configs/inference.json --input_dict "{'
 python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','label_prompt':[3]}"
 ```
 ```python
-# Interactive segmentation 
+# Interactive segmentation
 # Points must be three-dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]]. Point labels can only be -1 (ignore), 0 (negative), 1 (positive), 2 (negative for special overlapped classes like tumor), or 3 (positive for special classes). Only one class per inference is supported. The output value 255 represents NaN, i.e. regions that were not processed.
 python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','points':[[128,128,16], [100,100,16]],'point_labels':[1, 0]}"
 ```
@@ -158,7 +158,7 @@ python -m monai.bundle run --config_file="['configs/inference.json', 'configs/ba
 ### 1.1 Overlapped classes and postprocessing with [ShapeKit](https://arxiv.org/pdf/2506.24003)
 VISTA3D is trained with binary segmentation and may produce false positives due to weak false-positive supervision. ShapeKit solves this problem with sophisticated postprocessing. ShapeKit requires a segmentation mask for each class, while VISTA3D by default performs argmax and collapses overlapping classes. Change the `monai.apps.vista3d.transforms.VistaPostTransformd` in `inference.json` to save each class segmentation as a separate channel, then follow the [ShapeKit](https://github.com/BodyMaps/ShapeKit) codebase for processing.
 ```json
-{ 
+{
     "_target_": "Activationsd",
     "sigmoid": true,
     "keys": "pred"
@@ -180,7 +180,7 @@ To segment everything, run
 ```bash
 export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
 ```
-To segment based on point clicks, provide `point` and `point_label`. 
+To segment based on point clicks, provide `point` and `point_label`.
 ```bash
 export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --point "[[128,128,16],[100,100,6]]" --point_label "[1,0]" --save_mask true
 ```
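The override rule stated in the README hunk above (later configs win on overlapping keys) can be exercised with the multi-config form of `monai.bundle run` that the README already uses. A hedged sketch, where `configs/my_override.json` is a hypothetical file supplying only the keys to change:

```bash
# Later entries in the config_file list override earlier ones on overlapping keys:
# inference.json provides the defaults, my_override.json (hypothetical) patches them.
python -m monai.bundle run \
    --config_file "['configs/inference.json', 'configs/my_override.json']" \
    --input_dict "{'image':'spleen_03.nii.gz'}"
```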
diff --git a/vista3d/cvpr_workshop/infer_cvpr.py b/vista3d/cvpr_workshop/infer_cvpr.py
index 61d3536..0b3d7ef 100755
--- a/vista3d/cvpr_workshop/infer_cvpr.py
+++ b/vista3d/cvpr_workshop/infer_cvpr.py
@@ -16,6 +16,7 @@
 from train_cvpr import ROI_SIZE
 
+
 def convert_clicks(alldata):
     # indexes = list(alldata.keys())
     # data = [alldata[i] for i in indexes]
diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py
index ac099f6..2bfac66 100755
--- a/vista3d/cvpr_workshop/train_cvpr.py
+++ b/vista3d/cvpr_workshop/train_cvpr.py
@@ -22,7 +22,8 @@
 import matplotlib.pyplot as plt
 
 NUM_PATCHES_PER_IMAGE = 2
-ROI_SIZE= [128, 128, 128]
+ROI_SIZE = [128, 128, 128]
+
 
 def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs):
     """
@@ -109,7 +110,7 @@ def __getitem__(self, idx):
                 keys=["image", "label"],
                 label_key="label",
                 num_classes=label.max() + 1,
-                ratios=tuple(float(i > 0) for i in range(label.max()+1)),
+                ratios=tuple(float(i > 0) for i in range(label.max() + 1)),
                 num_samples=NUM_PATCHES_PER_IMAGE,
             ),
             monai.transforms.RandScaleIntensityd(
@@ -137,17 +138,19 @@ def __getitem__(self, idx):
                 mode=["constant", "constant"],
                 keys=["image", "label"],
                 spatial_size=ROI_SIZE,
-            )
+            ),
         ]
     )
     data = transforms(data)
     return data
 
+
 import re
 
+
 def get_latest_epoch(directory):
     # Pattern to match filenames like 'model_epoch<N>.pth'
-    pattern = re.compile(r'model_epoch(\d+)\.pth')
+    pattern = re.compile(r"model_epoch(\d+)\.pth")
     max_epoch = -1
 
     for filename in os.listdir(directory):
@@ -159,6 +162,7 @@ def get_latest_epoch(directory):
     return max_epoch if max_epoch != -1 else None
 
 
+
 # Training function
 def train():
     json_file = "allset.json"  # Update with your JSON file
@@ -169,7 +173,6 @@ def train():
     start_epoch = get_latest_epoch(checkpoint_dir)
     start_checkpoint = "./CPRR25_vista3D_model_final_10percent_data.pth"
 
-
     os.makedirs(checkpoint_dir, exist_ok=True)
     dist.init_process_group(backend="nccl")
     world_size = int(os.environ["WORLD_SIZE"])
@@ -189,11 +192,12 @@ def train():
         model.load_state_dict(pretrained_ckpt, strict=True)
     else:
         print(f"Resuming from epoch {start_epoch}")
-        pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
-        model.load_state_dict(pretrained_ckpt['model'], strict=True)
+        pretrained_ckpt = torch.load(
+            os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")
+        )
+        model.load_state_dict(pretrained_ckpt["model"], strict=True)
     model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
-
     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
     lr_scheduler = monai.optimizers.WarmupCosineSchedule(
         optimizer=optimizer,
@@ -265,10 +269,16 @@ def train():
             if local_rank == 0:
                 writer.add_scalar("loss", loss.item(), step)
         if local_rank == 0 and (epoch + 1) % save_interval == 0:
-            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch + 1}.pth")
+            checkpoint_path = os.path.join(
+                checkpoint_dir, f"model_epoch{epoch + 1}.pth"
+            )
             if world_size > 1:
                 torch.save(
-                    {"model": model.module.state_dict(), "epoch": epoch + 1, "step": step},
+                    {
+                        "model": model.module.state_dict(),
+                        "epoch": epoch + 1,
+                        "step": step,
+                    },
                     checkpoint_path,
                 )
                 print(
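For context on running the reformatted training script: `train_cvpr.py` reads `WORLD_SIZE` from the environment and initializes a `nccl` process group, so it is meant to be launched through a distributed launcher. A minimal sketch, assuming the script's entry point invokes `train()` on a single node; the GPU count is an assumption to adjust per machine:

```bash
# torchrun sets WORLD_SIZE and LOCAL_RANK, which train_cvpr.py reads before
# dist.init_process_group(backend="nccl"); checkpoints are written as
# model_epoch{N}.pth and picked up on restart via get_latest_epoch().
torchrun --nproc_per_node=8 train_cvpr.py
```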