diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 458c596a36..ac54e2ad0b 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -88,7 +88,8 @@ class SourceCode(BaseConfig): Parameters: source_dir (Optional[str]): - The local directory containing the source code to be used in the training job container. + The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that contains + the source code to be used in the training job container. requirements (Optional[str]): The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed requirements will be installed in the training job container. diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index aef6e3312b..4183fb87cd 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -407,28 +407,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]): "If 'requirements' or 'entry_script' is provided in 'source_code', " + "'source_dir' must also be provided.", ) - if not _is_valid_path(source_dir, path_type="Directory"): + if not ( + _is_valid_path(source_dir, path_type="Directory") + or _is_valid_s3_uri(source_dir, path_type="Directory") + or ( + _is_valid_path(source_dir, path_type="File") + and source_dir.endswith(".tar.gz") + ) + or ( + _is_valid_s3_uri(source_dir, path_type="File") + and source_dir.endswith(".tar.gz") + ) + ): raise ValueError( - f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.", + f"Invalid 'source_dir' path: {source_dir}. " + + "Must be a valid local directory, " + "s3 uri or path to tar.gz file stored locally or in s3.", ) if requirements: - if not _is_valid_path( - f"{source_dir}/{requirements}", - path_type="File", - ): - raise ValueError( - f"Invalid 'requirements': {requirements}. " - + "Must be a valid file within the 'source_dir'.", - ) + if not source_dir.endswith(".tar.gz"): + if not _is_valid_path( + f"{source_dir}/{requirements}", path_type="File" + ) and not _is_valid_s3_uri( + f"{source_dir}/{requirements}", path_type="File" + ): + raise ValueError( + f"Invalid 'requirements': {requirements}. " + + "Must be a valid file within the 'source_dir'.", + ) if entry_script: - if not _is_valid_path( - f"{source_dir}/{entry_script}", - path_type="File", - ): - raise ValueError( - f"Invalid 'entry_script': {entry_script}. " - + "Must be a valid file within the 'source_dir'.", - ) + if not source_dir.endswith(".tar.gz"): + if not _is_valid_path( + f"{source_dir}/{entry_script}", path_type="File" + ) and not _is_valid_s3_uri( + f"{source_dir}/{entry_script}", path_type="File" + ): + raise ValueError( + f"Invalid 'entry_script': {entry_script}. " + + "Must be a valid file within the 'source_dir'.", + ) def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" @@ -838,12 +855,17 @@ def _prepare_train_script( install_requirements = "" if source_code.requirements: - install_requirements = "echo 'Installing requirements'\n" - install_requirements = f"$SM_PIP_CMD install -r {source_code.requirements}" + install_requirements = ( + "echo 'Installing requirements'\n" + + f"$SM_PIP_CMD install -r {source_code.requirements}" + ) working_dir = "" if source_code.source_dir: - working_dir = f"cd {SM_CODE_CONTAINER_PATH}" + working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n" + if source_code.source_dir.endswith(".tar.gz"): + tarfile_name = os.path.basename(source_code.source_dir) + working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n" if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) diff --git a/tests/data/modules/script_mode/code.tar.gz b/tests/data/modules/script_mode/code.tar.gz new file mode 100644 index 0000000000..7c43f35f57 Binary files /dev/null and b/tests/data/modules/script_mode/code.tar.gz differ diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index a1e3106553..332b536d77 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -44,6 +44,24 @@ DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" +TAR_FILE_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode/code.tar.gz" +TAR_FILE_SOURCE_CODE = SourceCode( + source_dir=TAR_FILE_SOURCE_DIR, + requirements="requirements.txt", + entry_script="custom_script.py", +) + + +def test_source_dir_local_tar_file(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + source_code=TAR_FILE_SOURCE_CODE, + base_job_name="source_dir_local_tar_file", + ) + + model_trainer.train() + def test_hp_contract_basic_py_script(modules_sagemaker_session): model_trainer = ModelTrainer( diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 13530a3983..6001c5db36 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -92,9 +92,6 @@ source_dir=DEFAULT_SOURCE_DIR, entry_script="custom_script.py", ) -UNSUPPORTED_SOURCE_CODE = SourceCode( - entry_script="train.py", -) DEFAULT_ENTRYPOINT = ["/bin/bash"] DEFAULT_ARGUMENTS = [ "-c", @@ -152,7 +149,19 @@ def model_trainer(): { "init_params": { "training_image": DEFAULT_IMAGE, - "source_code": UNSUPPORTED_SOURCE_CODE, + "source_code": SourceCode( + entry_script="train.py", + ), + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/requirements.txt", + entry_script="custom_script.py", + ), }, "should_throw": True, }, @@ -163,13 +172,47 @@ def model_trainer(): }, "should_throw": False, }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir=f"{DEFAULT_SOURCE_DIR}/code.tar.gz", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/code/", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/code/code.tar.gz", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, ], ids=[ "no_params", "training_image_and_algorithm_name", "only_training_image", - "unsupported_source_code", - "supported_source_code", + "unsupported_source_code_missing_source_dir", + "unsupported_source_code_s3_other_file", + "supported_source_code_local_dir", + "supported_source_code_local_tar_file", + "supported_source_code_s3_dir", + "supported_source_code_s3_tar_file", ], ) def test_model_trainer_param_validation(test_case, modules_session):