diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index 4828ccb56e..547e73332f 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -14,6 +14,7 @@ import glob import os +import re import shlex import subprocess from typing import Any @@ -34,6 +35,8 @@ __all__ = ["nnUNetV2Runner"] +DATASET_ID_FORMAT = r"Dataset[0-9]{3}|[0-9]+" # regex format for a valid nnUnet dataset name + class nnUNetV2Runner: # noqa: N801 """ @@ -195,6 +198,13 @@ def __init__( # dataset_name_or_id has to be a string self.dataset_name_or_id = str(self.input_info.pop("dataset_name_or_id", 1)) + self.dataset_name: str | None = None + + # ensure the dataset name is a single identifier/number, this prevents code injection when composing commands + if re.fullmatch(DATASET_ID_FORMAT, self.dataset_name_or_id) is None: + raise ValueError( + f"Value for dataset_name_or_id `{self.dataset_name_or_id}` not a valid dataset name or ID." + ) try: from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name @@ -239,7 +249,7 @@ def convert_dataset(self): from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name - self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id)) + self.dataset_name = maybe_convert_to_dataset_name(self.dataset_name_or_id) datalist_json = ConfigParser.load_config_file(self.input_info.pop("datalist")) @@ -548,7 +558,7 @@ def train_single_model_command( Raises: ValueError: If gpu_id is an empty tuple or list. 
""" - env = os.environ.copy() + env: dict[str, str] = os.environ.copy() device_setting: str = "0" num_gpus = 1 if isinstance(gpu_id, str): @@ -574,22 +584,25 @@ def train_single_model_command( cmd = [ "nnUNetv2_train", - f"{self.dataset_name_or_id}", - f"{config}", - f"{fold}", + self.dataset_name_or_id, + config, + fold, "-tr", - f"{self.trainer_class_name}", + self.trainer_class_name, "-num_gpus", - f"{num_gpus}", + num_gpus, ] + if self.export_validation_probabilities: cmd.append("--npz") + for _key, _value in kwargs.items(): - if _key == "p" or _key == "pretrained_weights": - cmd.extend([f"-{_key}", f"{_value}"]) - else: - cmd.extend([f"--{_key}", f"{_value}"]) - return cmd, env + prefix = "-" if _key in {"p", "pretrained_weights"} else "--" + cmd += [f"{prefix}{_key}", str(_value)] + + cmd_str: list[str] = [str(c) for c in cmd] + + return cmd_str, env def train( self, @@ -641,7 +654,14 @@ def train_parallel_cmd( None (all available GPUs). kwargs: this optional parameter allows you to specify additional arguments defined in the ``train_single_model`` method. + + Raises: + ValueError: self.dataset_name must have a value, ie. when using an existing dataset or after creating one. """ + + if self.dataset_name is None: + raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.") + # unpack compressed files folder_names = [] for root, _, files in os.walk(os.path.join(self.nnunet_preprocessed, self.dataset_name)): @@ -696,7 +716,14 @@ def train_parallel( None (all available GPUs). kwargs: this optional parameter allows you to specify additional arguments defined in the ``train_single_model`` method. + + Raises: + ValueError: self.dataset_name must have a value, ie. when using an existing dataset or after creating one. 
""" + + if self.dataset_name is None: + raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.") + all_cmds = self.train_parallel_cmd(configs=configs, gpu_id_for_all=gpu_id_for_all, **kwargs) for s, cmds in enumerate(all_cmds): for gpu_id, gpu_cmd in cmds.items(): @@ -908,7 +935,14 @@ def predict_ensemble_postprocessing( run_postprocessing: whether to conduct post-processing kwargs: this optional parameter allows you to specify additional arguments defined in the ``predict`` method. + + Raises: + ValueError: self.dataset_name must have a value, ie. when using an existing dataset or after creating one. """ + + if self.dataset_name is None: + raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.") + from nnunetv2.ensembling.ensemble import ensemble_folders from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder from nnunetv2.utilities.file_path_utilities import get_output_folder diff --git a/tests/integration/test_integration_nnunetv2_runner.py b/tests/integration/test_integration_nnunetv2_runner.py index 44291e5722..f4cf0b4fb1 100644 --- a/tests/integration/test_integration_nnunetv2_runner.py +++ b/tests/integration/test_integration_nnunetv2_runner.py @@ -11,13 +11,16 @@ from __future__ import annotations +import logging import os import tempfile import unittest +from textwrap import dedent import nibabel as nib import numpy as np +import monai.apps.nnunet.nnunetv2_runner from monai.apps.nnunet import nnUNetV2Runner from monai.bundle.config_parser import ConfigParser from monai.data import create_test_image_3d @@ -27,6 +30,8 @@ _, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter") _, has_nnunet = optional_import("nnunetv2") +monai.apps.nnunet.nnunetv2_runner.logger.setLevel(logging.ERROR) # suppress warning logging to clean up test output + sim_datalist: dict[str, list[dict]] = { "testing": [{"image": "val_001.fake.nii.gz"}, {"image": 
"val_002.fake.nii.gz"}], "training": [ @@ -91,5 +96,75 @@ def tearDown(self) -> None: self.test_dir.cleanup() +@skip_if_quick +@unittest.skipIf(not has_nnunet, "no nnunetv2") +class TestnnUNetV2RunnerSecurity(unittest.TestCase): + def setUp(self) -> None: + self.test_dir = tempfile.TemporaryDirectory() + test_path = self.test_dir.name + + self.good_yml1 = os.path.join(test_path, "good1.yml") + self.good_yml2 = os.path.join(test_path, "good2.yml") + self.inject_yml = os.path.join(test_path, "test.yml") + + good_yml_content1 = f""" + dataset_name_or_id: Dataset123 + dataroot: {test_path}/data + datalist: {test_path}/lists/task4.json + work_dir: {test_path}/work + nnunet_raw: {test_path}/nnUNet_raw + nnunet_preprocessed: {test_path}/nnUNet_preprocessed + nnunet_results: {test_path}/nnUNet_results + """ + + with open(self.good_yml1, "w") as o: + o.write(dedent(good_yml_content1)) + + good_yml_content2 = f""" + dataset_name_or_id: 123 + dataroot: {test_path}/data + datalist: {test_path}/lists/task4.json + work_dir: {test_path}/work + nnunet_raw: {test_path}/nnUNet_raw + nnunet_preprocessed: {test_path}/nnUNet_preprocessed + nnunet_results: {test_path}/nnUNet_results + """ + + with open(self.good_yml2, "w") as o: + o.write(dedent(good_yml_content2)) + + # define a config file with code-injecting dataset name + injecting_yml_content = f""" + dataset_name_or_id: '4 & echo "This is exploited" > "{test_path}/test.txt" & rem' + dataroot: {test_path}/data + datalist: {test_path}/lists/task4.json + work_dir: {test_path}/work + nnunet_raw: {test_path}/nnUNet_raw + nnunet_preprocessed: {test_path}/nnUNet_preprocessed + nnunet_results: {test_path}/nnUNet_results + """ + + with open(self.inject_yml, "w") as o: + o.write(dedent(injecting_yml_content)) + + def test_nnunetv2runner_good_dataset_name(self) -> None: + """ + Test the dataset name given must conform to the nnUNet requirement of being an int or "Dataset###". 
+ """ + for ds in [self.good_yml1, self.good_yml2]: + with self.subTest(f"Testing {os.path.basename(ds)}"): + nnUNetV2Runner(input_config=ds, trainer_class_name="nnUNetTrainer") + + def test_nnunetv2runner_bad_dataset_name(self) -> None: + """ + Test that a dataset name which is neither an int nor "Dataset###", such as one carrying injected shell commands, is rejected with a ValueError. + """ + with self.assertRaises(ValueError): + nnUNetV2Runner(input_config=self.inject_yml, trainer_class_name="nnUNetTrainer") + + def tearDown(self) -> None: + self.test_dir.cleanup() + + if __name__ == "__main__": unittest.main()