Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions monai/apps/nnunet/nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import glob
import os
import re
import shlex
import subprocess
from typing import Any
Expand All @@ -34,6 +35,8 @@

__all__ = ["nnUNetV2Runner"]

# Regex (applied with ``re.fullmatch``) describing an acceptable nnU-Net dataset identifier: either a bare numeric ID
# or a dataset name of the form "DatasetXXX", optionally followed by "_<name>" (e.g. "Dataset005_Prostate").
# The suffix is limited to letters, digits, "_" and "-" so no whitespace or shell metacharacters can get through.
DATASET_ID_FORMAT = r"Dataset[0-9]{3}(?:_[A-Za-z0-9_-]+)?|[0-9]+"


class nnUNetV2Runner: # noqa: N801
"""
Expand Down Expand Up @@ -195,6 +198,13 @@ def __init__(

# dataset_name_or_id has to be a string
self.dataset_name_or_id = str(self.input_info.pop("dataset_name_or_id", 1))
self.dataset_name: str | None = None

# ensure the dataset name is a single identifier/number, this prevents code injection when composing commands
if re.fullmatch(DATASET_ID_FORMAT, self.dataset_name_or_id) is None:
raise ValueError(
f"Value for dataset_name_or_id `{self.dataset_name_or_id}` not a valid dataset name or ID."
)

try:
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
Expand Down Expand Up @@ -239,7 +249,7 @@ def convert_dataset(self):

from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id))
self.dataset_name = maybe_convert_to_dataset_name(self.dataset_name_or_id)

datalist_json = ConfigParser.load_config_file(self.input_info.pop("datalist"))

Expand Down Expand Up @@ -548,7 +558,7 @@ def train_single_model_command(
Raises:
ValueError: If gpu_id is an empty tuple or list.
"""
env = os.environ.copy()
env: dict[str, str] = os.environ.copy()
device_setting: str = "0"
num_gpus = 1
if isinstance(gpu_id, str):
Expand All @@ -574,22 +584,25 @@ def train_single_model_command(

cmd = [
"nnUNetv2_train",
f"{self.dataset_name_or_id}",
f"{config}",
f"{fold}",
self.dataset_name_or_id,
config,
fold,
"-tr",
f"{self.trainer_class_name}",
self.trainer_class_name,
"-num_gpus",
f"{num_gpus}",
num_gpus,
]

if self.export_validation_probabilities:
cmd.append("--npz")

for _key, _value in kwargs.items():
if _key == "p" or _key == "pretrained_weights":
cmd.extend([f"-{_key}", f"{_value}"])
else:
cmd.extend([f"--{_key}", f"{_value}"])
return cmd, env
prefix = "-" if _key in {"p", "pretrained_weights"} else "--"
cmd += [f"{prefix}{_key}", str(_value)]

cmd_str: list[str] = [str(c) for c in cmd]

return cmd_str, env
Comment on lines 585 to +605
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

nnUNetv2_train argparse flags store_true --c --val --use_compressed --disable_checkpointing

💡 Result:

In nnU-Net v2, the command line utility nnUNetv2_train utilizes several argparse flags to control the training and validation process. These flags are implemented as store_true options, meaning they do not require an associated value and serve as toggles [1][2]. The requested flags perform the following functions: --c (Continue training): This flag instructs the trainer to resume training from the most recent checkpoint. It is essential if a training run was interrupted or if you wish to extend a previously stopped training session [2]. --val (Validation only): This flag sets the trainer to skip the training process and instead run only the final validation on the specified fold [2][3]. This is often used after a training run has completed or to re-run validation without re-training [4][3]. --use_compressed: This flag directs the system to use compressed data during training. By default, nnU-Net may decompress data; enabling this flag prevents decompression. It is noted that reading compressed data is significantly more CPU and potentially RAM intensive, and it should only be used when necessary [2]. --disable_checkpointing: This flag disables the saving of model checkpoints during training. This is useful in scenarios where you are testing a configuration or running quick experiments and do not want to consume disk space with multiple large model checkpoint files [2]. In addition to these, the --npz flag is frequently used alongside these commands. Although not explicitly in your list, it is critical to note that --npz is required if you intend to use nnUNetv2_find_best_configuration later, as it forces the model to save the softmax probabilities from the final validation, which are necessary for ensemble creation and model selection [4][3].

Citations:


Fix argv construction for store_true kwargs (e.g., --c, --val) in monai/apps/nnunet/nnunetv2_runner.py (lines 585-605): the kwargs loop always appends str(_value), so documented boolean flags (c, val, use_compressed, disable_checkpointing) become --c True / --val True instead of being emitted as bare flags (the same pattern as the special-case --npz). Update the builder to emit --<flag> only when the value is True and omit the flag when False (and similarly only include pretrained_weights when it’s a real path, not False).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/nnunet/nnunetv2_runner.py` around lines 585 - 605, The argv
builder in the kwargs loop currently always appends str(_value), causing
store_true options to appear as "--flag True" instead of a bare "--flag" and
including false/empty pretrained_weights; update the loop in nnunetv2_runner.py
where cmd is assembled so that for boolean flags (e.g., keys "c", "val",
"use_compressed", "disable_checkpointing" and any other store_true options) you
only append the flag name (prefix + _key) when _value is True and skip it when
False, and for "pretrained_weights" (or "-p") only append the option and its
value when _value is truthy (a non-empty path); keep the existing prefix
selection logic for short vs long options and continue converting items to
strings later.


def train(
self,
Expand Down Expand Up @@ -641,7 +654,14 @@ def train_parallel_cmd(
None (all available GPUs).
kwargs: this optional parameter allows you to specify additional arguments defined in the
``train_single_model`` method.

Raises:
ValueError: self.dataset_name must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

# unpack compressed files
folder_names = []
for root, _, files in os.walk(os.path.join(self.nnunet_preprocessed, self.dataset_name)):
Expand Down Expand Up @@ -696,7 +716,14 @@ def train_parallel(
None (all available GPUs).
kwargs: this optional parameter allows you to specify additional arguments defined in the
``train_single_model`` method.

Raises:
ValueError: self.dataset_name must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

all_cmds = self.train_parallel_cmd(configs=configs, gpu_id_for_all=gpu_id_for_all, **kwargs)
for s, cmds in enumerate(all_cmds):
for gpu_id, gpu_cmd in cmds.items():
Expand Down Expand Up @@ -908,7 +935,14 @@ def predict_ensemble_postprocessing(
run_postprocessing: whether to conduct post-processing
kwargs: this optional parameter allows you to specify additional arguments defined in the
``predict`` method.

Raises:
ValueError: self.dataset_name must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

from nnunetv2.ensembling.ensemble import ensemble_folders
from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder
from nnunetv2.utilities.file_path_utilities import get_output_folder
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/test_integration_nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@

from __future__ import annotations

import logging
import os
import tempfile
import unittest
from textwrap import dedent

import nibabel as nib
import numpy as np

import monai.apps.nnunet.nnunetv2_runner
from monai.apps.nnunet import nnUNetV2Runner
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
Expand All @@ -27,6 +30,8 @@
_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
_, has_nnunet = optional_import("nnunetv2")

monai.apps.nnunet.nnunetv2_runner.logger.setLevel(logging.ERROR) # suppress warning logging to clean up test output

sim_datalist: dict[str, list[dict]] = {
"testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
"training": [
Expand Down Expand Up @@ -91,5 +96,75 @@ def tearDown(self) -> None:
self.test_dir.cleanup()


@skip_if_quick
@unittest.skipIf(not has_nnunet, "no nnunetv2")
class TestnnUNetV2RunnerSecurity(unittest.TestCase):
    """
    Check that ``nnUNetV2Runner`` validates the ``dataset_name_or_id`` value read from its input config.

    Three config files are written into a temporary directory: two with acceptable values ("Dataset123" and "123")
    and one whose value embeds shell commands that would write a file if passed unescaped to a shell.
    """

    def setUp(self) -> None:
        self.test_dir = tempfile.TemporaryDirectory()
        # registered straight away so the directory is removed even if the rest of setUp raises,
        # in which case unittest would not call tearDown
        self.addCleanup(self.test_dir.cleanup)
        test_path = self.test_dir.name

        # file the injected command would create if it were ever executed
        self.exploit_target = f"{test_path}/test.txt"

        self.good_yml1 = self._write_config("good1.yml", "Dataset123")
        self.good_yml2 = self._write_config("good2.yml", "123")

        # define a config file with code-injecting dataset name, single-quoted so YAML reads it as one string
        self.inject_yml = self._write_config(
            "test.yml", f"""'4 & echo "This is exploited" > "{self.exploit_target}" & rem'"""
        )

    def _write_config(self, filename: str, dataset_name_or_id: str) -> str:
        """
        Write a runner config file into the temporary directory and return its path.

        Args:
            filename: name of the file to create inside the temporary directory.
            dataset_name_or_id: text inserted verbatim as the YAML value of ``dataset_name_or_id``; callers must
                add any YAML quoting themselves.

        Returns:
            Full path of the written config file.
        """
        test_path = self.test_dir.name
        content = f"""
        dataset_name_or_id: {dataset_name_or_id}
        dataroot: {test_path}/data
        datalist: {test_path}/lists/task4.json
        work_dir: {test_path}/work
        nnunet_raw: {test_path}/nnUNet_raw
        nnunet_preprocessed: {test_path}/nnUNet_preprocessed
        nnunet_results: {test_path}/nnUNet_results
        """

        config_path = os.path.join(test_path, filename)
        with open(config_path, "w") as o:
            o.write(dedent(content))

        return config_path

    def test_nnunetv2runner_good_dataset_name(self) -> None:
        """
        Test that a dataset name given as an int or as "Dataset###" is accepted when constructing the runner.
        """
        for ds in [self.good_yml1, self.good_yml2]:
            with self.subTest(f"Testing {os.path.basename(ds)}"):
                nnUNetV2Runner(input_config=ds, trainer_class_name="nnUNetTrainer")

    def test_nnunetv2runner_bad_dataset_name(self) -> None:
        """
        Test that a dataset name containing injected shell commands is rejected with ValueError.
        """
        with self.assertRaises(ValueError):
            nnUNetV2Runner(input_config=self.inject_yml, trainer_class_name="nnUNetTrainer")

        # the injected command must never have run
        self.assertFalse(os.path.exists(self.exploit_target))


if __name__ == "__main__":
unittest.main()
Loading