From b898f76dbde5602513da3e0ddc299844ba5b96eb Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 23:27:11 -0800 Subject: [PATCH 1/6] fix: make configs safer --- src/sagemaker/modules/configs.py | 6 +++++- src/sagemaker/modules/distributed.py | 6 +++++- src/sagemaker/modules/train/model_trainer.py | 8 ++++++-- .../sagemaker/modules/train/test_model_trainer.py | 11 ++++++++++- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index ec0df519f5..0d0b7a0608 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -22,7 +22,7 @@ from __future__ import absolute_import from typing import Optional, Union -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, model_validator, ConfigDict import sagemaker_core.shapes as shapes @@ -94,6 +94,8 @@ class SourceCode(BaseModel): If not specified, entry_script must be provided. """ + model_config = ConfigDict(validate_assignment=True, extra="forbid") + source_dir: Optional[str] = None requirements: Optional[str] = None entry_script: Optional[str] = None @@ -215,5 +217,7 @@ class InputData(BaseModel): S3DataSource object, or FileSystemDataSource object. """ + model_config = ConfigDict(validate_assignment=True, extra="forbid") + channel_name: str = None data_source: Union[str, FileSystemDataSource, S3DataSource] = None diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 6cdc136dcf..06297a1924 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from typing import Optional, Dict, Any, List -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel, PrivateAttr, ConfigDict from sagemaker.modules.utils import safe_serialize @@ -53,6 +53,8 @@ class SMP(BaseModel): parallelism or expert parallelism. """ + model_config = ConfigDict(validate_assignment=True, extra="forbid") + hybrid_shard_degree: Optional[int] = None sm_activation_offloading: Optional[bool] = None activation_loading_horizon: Optional[int] = None @@ -75,6 +77,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: class DistributedConfig(BaseModel): """Base class for distributed training configurations.""" + model_config = ConfigDict(validate_assignment=True, extra="forbid") + _type: str = PrivateAttr() def model_dump(self, *args, **kwargs): diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 31decfaca9..49a78417d4 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -792,14 +792,14 @@ def _prepare_train_script( """Prepare the training script to be executed in the training job container. Args: - source_code (SourceCodeConfig): The source code configuration. + source_code (SourceCode): The source code configuration. """ base_command = "" if source_code.command: if source_code.entry_script: logger.warning( - "Both 'command' and 'entry_script' are provided in the SourceCodeConfig. " + "Both 'command' and 'entry_script' are provided in the SourceCode. " + "Defaulting to 'command'." ) base_command = source_code.command.split() @@ -831,6 +831,10 @@ def _prepare_train_script( + "Only .py and .sh scripts are supported." ) execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER + else: + raise ValueError( + f"Invalid configuration, please provide a valid SourceCode: {source_code}" + ) train_script = TRAIN_SCRIPT_TEMPLATE.format( working_dir=working_dir, diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 093da20ab8..39ac8ab3f2 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -438,7 +438,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": MPI( - custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], + mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], ), "expected_template": EXECUTE_MPI_DRIVER, "expected_hyperparameters": {}, @@ -1059,3 +1059,12 @@ def mock_upload_data(path, bucket, key_prefix): hyper_parameters=hyperparameters, environment=environment, ) + + +def test_safe_configs(): + # Test extra fails + with pytest.raises(ValueError): + SourceCode(entry_point="train.py") + # Test invalid type fails + with pytest.raises(ValueError): + SourceCode(entry_script=1) From 6223f260a58bfb1211529b6486900c61be5b4ff2 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 14 Feb 2025 10:24:46 -0800 Subject: [PATCH 2/6] fix: safer destructor in ModelTrainer --- src/sagemaker/modules/train/model_trainer.py | 11 +++++--- .../modules/train/test_model_trainer.py | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 49a78417d4..f454498008 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -205,7 +205,9 @@ class ModelTrainer(BaseModel): "LOCAL_CONTAINER" mode. """ - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, extra="forbid" + ) training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB sagemaker_session: Optional[Session] = None @@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self): def __del__(self): """Destructor method to clean up the temporary directory.""" - # Clean up the temporary directory if it exists - if self._temp_recipe_train_dir is not None: - self._temp_recipe_train_dir.cleanup() + # Clean up the temporary directory if it exists and class was initialized + if hasattr(self, "__pydantic_fields_set__"): + if self._temp_recipe_train_dir is not None: + self._temp_recipe_train_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 39ac8ab3f2..6c9ca7f53e 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -18,6 +18,7 @@ import json import os import pytest +from pydantic import ValidationError from unittest.mock import patch, MagicMock, ANY from sagemaker import image_uris @@ -1068,3 +1069,27 @@ def test_safe_configs(): # Test invalid type fails with pytest.raises(ValueError): SourceCode(entry_script=1) + + +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +def test_destructor_cleanup(mock_tmp_dir, modules_session): + + with pytest.raises(ValidationError): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute="test" + ) + mock_tmp_dir.cleanup.assert_not_called() + + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + model_trainer._temp_recipe_train_dir = mock_tmp_dir + mock_tmp_dir.assert_not_called() + del model_trainer + mock_tmp_dir.cleanup.assert_called_once() From bead4945896c0d09752de2d76695d2b2c4cc8410 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 14 Feb 2025 10:41:03 -0800 Subject: [PATCH 3/6] format --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 6c9ca7f53e..29da03bcd9 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1073,13 +1073,13 @@ def test_safe_configs(): @patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") def test_destructor_cleanup(mock_tmp_dir, modules_session): - + with pytest.raises(ValidationError): model_trainer = ModelTrainer( training_image=DEFAULT_IMAGE, role=DEFAULT_ROLE, sagemaker_session=modules_session, - compute="test" + compute="test", ) mock_tmp_dir.cleanup.assert_not_called() From 038658377c172c8f138ecc89cf1a209da77a7bdc Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 14 Feb 2025 17:18:37 -0800 Subject: [PATCH 4/6] Update error message --- src/sagemaker/modules/train/model_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index f454498008..21c3004e16 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -835,8 +835,10 @@ def _prepare_train_script( ) execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER else: + # This should never be reached, as the source_code should have been validated. raise ValueError( - f"Invalid configuration, please provide a valid SourceCode: {source_code}" + f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}." + + f"Please provide a valid configuration with atleast one of 'command' or entry_script'." ) train_script = TRAIN_SCRIPT_TEMPLATE.format( From 5b8d790b579175fd0f6adf9c85625fec9e42a16d Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 14 Feb 2025 19:36:59 -0800 Subject: [PATCH 5/6] pylint --- src/sagemaker/modules/train/model_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 21c3004e16..a47d8f91ad 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -838,7 +838,8 @@ def _prepare_train_script( # This should never be reached, as the source_code should have been validated. raise ValueError( f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}." - + f"Please provide a valid configuration with atleast one of 'command' or entry_script'." + + "Please provide a valid configuration with atleast one of 'command'" + + " or entry_script'." ) train_script = TRAIN_SCRIPT_TEMPLATE.format( From 3128278d3059b287f5a150b044c65c2a3b036820 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 24 Feb 2025 10:36:00 -0800 Subject: [PATCH 6/6] Create BaseConfig --- src/sagemaker/modules/configs.py | 14 ++++++++------ src/sagemaker/modules/distributed.py | 11 ++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 0d0b7a0608..458c596a36 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -74,7 +74,13 @@ ] -class SourceCode(BaseModel): +class BaseConfig(BaseModel): + """BaseConfig""" + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + +class SourceCode(BaseConfig): """SourceCode. The SourceCode class allows the user to specify the source code location, dependencies, @@ -94,8 +100,6 @@ class SourceCode(BaseModel): If not specified, entry_script must be provided. """ - model_config = ConfigDict(validate_assignment=True, extra="forbid") - source_dir: Optional[str] = None requirements: Optional[str] = None entry_script: Optional[str] = None @@ -196,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig: return shapes.VpcConfig(**filtered_dict) -class InputData(BaseModel): +class InputData(BaseConfig): """InputData. This config allows the user to specify an input data source for the training job. @@ -217,7 +221,5 @@ class InputData(BaseModel): S3DataSource object, or FileSystemDataSource object. """ - model_config = ConfigDict(validate_assignment=True, extra="forbid") - channel_name: str = None data_source: Union[str, FileSystemDataSource, S3DataSource] = None diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 06297a1924..f28589de54 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -14,11 +14,12 @@ from __future__ import absolute_import from typing import Optional, Dict, Any, List -from pydantic import BaseModel, PrivateAttr, ConfigDict +from pydantic import PrivateAttr from sagemaker.modules.utils import safe_serialize +from sagemaker.modules.configs import BaseConfig -class SMP(BaseModel): +class SMP(BaseConfig): """SMP. This class is used for configuring the SageMaker Model Parallelism v2 parameters. @@ -53,8 +54,6 @@ class SMP(BaseModel): parallelism or expert parallelism. """ - model_config = ConfigDict(validate_assignment=True, extra="forbid") - hybrid_shard_degree: Optional[int] = None sm_activation_offloading: Optional[bool] = None activation_loading_horizon: Optional[int] = None @@ -74,11 +73,9 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseModel): +class DistributedConfig(BaseConfig): """Base class for distributed training configurations.""" - model_config = ConfigDict(validate_assignment=True, extra="forbid") - _type: str = PrivateAttr() def model_dump(self, *args, **kwargs):