From 8c277e779249cdc32a8a4eb9206916fa89e68dbb Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 16:55:12 -0800 Subject: [PATCH 01/16] feat: Make DistributedConfig Extensible --- src/sagemaker/modules/distributed.py | 66 ++++++++++++---- src/sagemaker/modules/templates.py | 13 +--- .../train/container_drivers/__init__.py | 4 +- .../container_drivers/common/__init__.py | 14 ++++ .../container_drivers/{ => common}/utils.py | 4 +- .../container_drivers/drivers/__init__.py | 14 ++++ .../{ => drivers}/basic_script_driver.py | 14 ++-- .../{ => drivers}/mpi_driver.py | 35 +++++---- .../{ => drivers}/mpi_utils.py | 14 +++- .../{ => drivers}/torchrun_driver.py | 21 ++--- .../container_drivers/scripts/__init__.py | 2 +- .../container_drivers/scripts/environment.py | 24 +++++- src/sagemaker/modules/train/model_trainer.py | 43 +++++----- .../scripts/test_enviornment.py | 4 +- .../container_drivers/test_mpi_driver.py | 58 ++++++-------- .../train/container_drivers/test_mpi_utils.py | 8 +- .../container_drivers/test_torchrun_driver.py | 78 +++++++------------ .../train/container_drivers/test_utils.py | 2 +- .../modules/train/test_model_trainer.py | 14 +++- 19 files changed, 250 insertions(+), 182 deletions(-) create mode 100644 src/sagemaker/modules/train/container_drivers/common/__init__.py rename src/sagemaker/modules/train/container_drivers/{ => common}/utils.py (98%) create mode 100644 src/sagemaker/modules/train/container_drivers/drivers/__init__.py rename src/sagemaker/modules/train/container_drivers/{ => drivers}/basic_script_driver.py (89%) rename src/sagemaker/modules/train/container_drivers/{ => drivers}/mpi_driver.py (83%) rename src/sagemaker/modules/train/container_drivers/{ => drivers}/mpi_utils.py (97%) rename src/sagemaker/modules/train/container_drivers/{ => drivers}/torchrun_driver.py (88%) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index f28589de54..580ba16af9 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -13,10 +13,13 @@ """Distributed module.""" from __future__ import absolute_import +import os + +from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List -from pydantic import PrivateAttr +from pydantic import BaseModel from sagemaker.modules.utils import safe_serialize -from sagemaker.modules.configs import BaseConfig +from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH class SMP(BaseConfig): @@ -73,16 +76,39 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseConfig): - """Base class for distributed training configurations.""" +class DistributedConfig(BaseModel, ABC): + """Abstract base class for distributed training configurations. + + This class defines the interface that all distributed training configurations + must implement. It provides a standardized way to specify driver scripts and + their locations for distributed training jobs. + """ + + @property + @abstractmethod + def driver_dir(self) -> str: + """Directory containing the driver script. + + This property should return the path to the directory containing + the driver script, relative to the container's working directory. + + Returns: + str: Path to directory containing the driver script + """ + pass + + @property + @abstractmethod + def driver_script(self) -> str: + """Name of the driver script. - _type: str = PrivateAttr() + This property should return the name of the Python script that implements + the distributed training driver logic. - def model_dump(self, *args, **kwargs): - """Dump the model to a dictionary.""" - result = super().model_dump(*args, **kwargs) - result["_type"] = self._type - return result + Returns: + str: Name of the driver script file + """ + pass class Torchrun(DistributedConfig): @@ -99,11 +125,17 @@ class Torchrun(DistributedConfig): The SageMaker Model Parallelism v2 parameters. """ - _type: str = PrivateAttr(default="torchrun") - process_count_per_node: Optional[int] = None smp: Optional["SMP"] = None + @property + def driver_dir(self) -> str: + return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + + @property + def driver_script(self) -> str: + return "torchrun_driver.py" + class MPI(DistributedConfig): """MPI. @@ -119,7 +151,13 @@ class MPI(DistributedConfig): The custom MPI options to use for the training job. """ - _type: str = PrivateAttr(default="mpi") - process_count_per_node: Optional[int] = None mpi_additional_options: Optional[List[str]] = None + + @property + def driver_dir(self) -> str: + return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + + @property + def driver_script(self) -> str: + return "mpi_driver.py" diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index fba60dda47..9dfef646ed 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -21,17 +21,12 @@ EXECUTE_BASIC_SCRIPT_DRIVER = """ echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py """ -EXEUCTE_TORCHRUN_DRIVER = """ -echo "Running Torchrun driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py -""" - -EXECUTE_MPI_DRIVER = """ -echo "Running MPI driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py +EXEUCTE_DISTRIBUTED_DRIVER = """ +echo "Running {driver_name} Driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script} """ TRAIN_SCRIPT_TEMPLATE = """ diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py index 18557a2eb5..f59a8d25f2 100644 --- a/src/sagemaker/modules/train/container_drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" -from __future__ import absolute_import +"""Sagemaker modules container drivers directory.""" +from __future__ import absolute_import \ No newline at end of file diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py new file mode 100644 index 0000000000..64e4ba0091 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - common directory.""" +from __future__ import absolute_import \ No newline at end of file diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/common/utils.py similarity index 98% rename from src/sagemaker/modules/train/container_drivers/utils.py rename to src/sagemaker/modules/train/container_drivers/common/utils.py index e939a6e0b8..c07aa1359a 100644 --- a/src/sagemaker/modules/train/container_drivers/utils.py +++ b/src/sagemaker/modules/train/container_drivers/common/utils.py @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME return hyperparameters_dict -def get_process_count(distributed_dict: Dict[str, Any]) -> int: +def get_process_count(process_count: Optional[int] = None) -> int: """Get the number of processes to run on each node in the training job.""" return ( - int(distributed_dict.get("process_count_per_node", 0)) + process_count or int(os.environ.get("SM_NUM_GPUS", 0)) or int(os.environ.get("SM_NUM_NEURONS", 0)) or 1 diff --git a/src/sagemaker/modules/train/container_drivers/drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py new file mode 100644 index 0000000000..68c15efa09 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - drivers directory.""" +from __future__ import absolute_import \ No newline at end of file diff --git a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py similarity index 89% rename from src/sagemaker/modules/train/container_drivers/basic_script_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py index cb0278bc9f..f39563ed92 100644 --- a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py @@ -13,16 +13,19 @@ """This module is the entry point for the Basic Script Driver.""" from __future__ import absolute_import +import os import sys +import json import shlex +from pathlib import Path from typing import List -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 logger, get_python_executable, - read_source_code_json, - read_hyperparameters_json, execute_commands, write_failure_file, hyperparameters_to_cli_args, @@ -31,11 +34,10 @@ def create_commands() -> List[str]: """Create the commands to execute.""" - source_code = read_source_code_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + hyperparameters = json.loads(os.environ["SM_HPS"]) python_executable = get_python_executable() - entry_script = source_code["entry_script"] args = hyperparameters_to_cli_args(hyperparameters) if entry_script.endswith(".py"): commands = [python_executable, entry_script] diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py similarity index 83% rename from src/sagemaker/modules/train/container_drivers/mpi_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py index dceb748cc0..c71737b277 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py @@ -16,18 +16,8 @@ import os import sys import json +from pathlib import Path -from utils import ( - logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, - USER_CODE_PATH, -) from mpi_utils import ( start_sshd_daemon, bootstrap_master_node, @@ -38,6 +28,16 @@ ) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from common.utils import ( # noqa: E402 + logger, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, +) + + def main(): """Main function for the MPI driver script. @@ -58,9 +58,9 @@ def main(): 5. Exit """ - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) sm_current_host = os.environ["SM_CURRENT_HOST"] sm_hosts = json.loads(os.environ["SM_HOSTS"]) @@ -77,7 +77,8 @@ def main(): host_list = json.loads(os.environ["SM_HOSTS"]) host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config.get("process_count_per_node", 0)) + process_count = get_process_count(process_count) if process_count > 1: host_list = ["{}:{}".format(host, process_count) for host in host_list] @@ -86,8 +87,8 @@ def main(): host_count=host_count, host_list=host_list, num_processes=process_count, - additional_options=distribution.get("mpi_additional_options", []), - entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]), + additional_options=distributed_config.get("mpi_additional_options", []), + entry_script_path=entry_script, ) args = hyperparameters_to_cli_args(hyperparameters) diff --git a/src/sagemaker/modules/train/container_drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py similarity index 97% rename from src/sagemaker/modules/train/container_drivers/mpi_utils.py rename to src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py index 00ddc815cd..d19bc2d78d 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_utils.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py @@ -14,12 +14,22 @@ from __future__ import absolute_import import os +import sys import subprocess import time +import paramiko + +from pathlib import Path from typing import List -import paramiko -from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + get_python_executable, + logger, +) FINISHED_STATUS_FILE = "/tmp/done.algo-1" READY_FILE = "/tmp/ready.%s" diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py similarity index 88% rename from src/sagemaker/modules/train/container_drivers/torchrun_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py index 666479ec84..30410a538a 100644 --- a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py @@ -15,20 +15,20 @@ import os import sys +import json +from pathlib import Path from typing import List, Tuple -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, get_python_executable, execute_commands, write_failure_file, - USER_CODE_PATH, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, ) @@ -65,11 +65,12 @@ def setup_env(): def create_commands(): """Create the Torch Distributed command to execute""" - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config.get("process_count_per_node", 0)) + process_count = get_process_count(process_count) host_count = int(os.environ["SM_HOST_COUNT"]) torch_cmd = [] @@ -94,7 +95,7 @@ def create_commands(): ] ) - torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])]) + torch_cmd.extend([entry_script]) args = hyperparameters_to_cli_args(hyperparameters) torch_cmd += args diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py index 1abbce4067..f04c5b17a0 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules scripts directory.""" +"""Sagemaker modules container drivers - scripts directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index ea6abac425..3e405c4289 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -19,12 +19,17 @@ import json import os import sys +from pathlib import Path import logging -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -sys.path.insert(0, parent_dir) +sys.path.insert(0, str(Path(__file__).parent.parent)) -from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413 +from common.utils import ( # noqa: E402 + safe_serialize, + safe_deserialize, + read_distributed_json, + read_source_code_json, +) # Initialize logger SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) @@ -42,6 +47,8 @@ SM_OUTPUT_DIR = "/opt/ml/output" SM_OUTPUT_FAILURE = "/opt/ml/output/failure" SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" +SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" +SM_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/drivers" SM_MASTER_ADDR = "algo-1" SM_MASTER_PORT = 7777 @@ -158,6 +165,17 @@ def set_env( "SM_MASTER_PORT": SM_MASTER_PORT, } + # SourceCode and DistributedConfig Environment Variables + source_code = read_source_code_json() + if source_code: + env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH + env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") + + distributed = read_distributed_json() + if distributed: + env_vars["SM_DRIVER_DIR"] = SM_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_CONFIG"] = distributed + # Data Channels channels = list(input_data_config.keys()) for channel in channels: diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index a47d8f91ad..ab518be4bc 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -70,7 +70,7 @@ ) from sagemaker.modules.local_core.local_container import _LocalContainer -from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig +from sagemaker.modules.distributed import Torchrun, DistributedConfig from sagemaker.modules.utils import ( _get_repo_name_from_image, _get_unique_name, @@ -94,8 +94,7 @@ from sagemaker.modules.templates import ( TRAIN_SCRIPT_TEMPLATE, EXECUTE_BASE_COMMANDS, - EXECUTE_MPI_DRIVER, - EXEUCTE_TORCHRUN_DRIVER, + EXEUCTE_DISTRIBUTED_DRIVER, EXECUTE_BASIC_SCRIPT_DRIVER, ) from sagemaker.telemetry.telemetry_logging import _telemetry_emitter @@ -153,7 +152,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed (Optional[Union[MPI, Torchrun]]): + distributed (Optional[Union[DistributedConfig]]): The distributed runner for the training job. This is used to configure a distributed training job. If specifed, ``source_code`` must also be provided. @@ -214,7 +213,7 @@ class ModelTrainer(BaseModel): role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed: Optional[Union[MPI, Torchrun]] = None + distributed: Optional[Union[DistributedConfig]] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None @@ -537,12 +536,17 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - drivers_dir = TemporaryDirectory( - prefix=os.path.join(self.local_container_root + "/") - ) + tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - drivers_dir = TemporaryDirectory() - shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True) + tmp_dir = TemporaryDirectory() + # Copy everything under container_drivers/ to a temporary directory + shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + + # If distributed is provided, overwrite code under /drivers + if self.distributed: + distributed_driver_dir = self.distributed.driver_dir + driver_dir = os.path.join(tmp_dir.name, "drivers") + shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code # The source code will be mounted at /opt/ml/input/data/code in the container @@ -555,7 +559,7 @@ def train( input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=drivers_dir, + tmp_dir=tmp_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -564,13 +568,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=drivers_dir.name, + data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) input_data_config.append(sm_drivers_channel) @@ -820,13 +824,10 @@ def _prepare_train_script( if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) elif distributed: - distribution_type = distributed._type - if distribution_type == "mpi": - execute_driver = EXECUTE_MPI_DRIVER - elif distribution_type == "torchrun": - execute_driver = EXEUCTE_TORCHRUN_DRIVER - else: - raise ValueError(f"Unsupported distribution type: {distribution_type}.") + execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name=distributed.__class__.__name__, + driver_script=distributed.driver_script, + ) elif source_code.entry_script and not source_code.command and not distributed: if not source_code.entry_script.endswith((".py", ".sh")): raise ValueError( diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index 30d6dfdf6c..1024f60646 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -21,12 +21,10 @@ from sagemaker.modules.train.container_drivers.scripts.environment import ( set_env, - log_key_value, log_env_variables, - mask_sensitive_info, HIDDEN_VALUE, ) -from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize +from sagemaker.modules.train.container_drivers.common.utils import safe_serialize, safe_deserialize RESOURCE_CONFIG = dict( current_host="algo-1", diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index a1a84da1ab..a752360981 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -15,13 +15,14 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() sys.modules["mpi_utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import mpi_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.drivers import mpi_driver # noqa: E402 DUMMY_MPI_COMMAND = [ @@ -40,12 +41,7 @@ "script.py", ] -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} DUMMY_DISTRIBUTED = { - "_type": "mpi", "process_count_per_node": 2, "mpi_additional_options": [ "--verbose", @@ -62,17 +58,18 @@ "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, mock_get_mpirun_command, @@ -81,12 +78,8 @@ def test_mpi_driver_worker( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mpi_driver.main() @@ -106,19 +99,20 @@ def test_mpi_driver_worker( "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_status_file_to_workers") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_process_count") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_status_file_to_workers") def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands, @@ -129,12 +123,8 @@ def test_mpi_driver_master( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_config_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND mock_get_process_count.return_value = 2 mock_execute_commands.return_value = (0, "") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py index 2328b1ace5..6c9f2545f0 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -27,7 +27,7 @@ mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") with patch.dict("sys.modules", {"utils": mock_utils}): - from sagemaker.modules.train.container_drivers.mpi_utils import ( + from sagemaker.modules.train.container_drivers.drivers.mpi_utils import ( CustomHostKeyPolicy, _can_connect, write_status_file_to_workers, @@ -65,7 +65,7 @@ def test_custom_host_key_policy_invalid_hostname(): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_can_connect_success(mock_logger, mock_ssh_client): """Test successful SSH connection.""" mock_client = Mock() @@ -81,7 +81,7 @@ def test_can_connect_success(mock_logger, mock_ssh_client): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_can_connect_failure(mock_logger, mock_ssh_client): """Test SSH connection failure.""" mock_client = Mock() @@ -97,7 +97,7 @@ def test_can_connect_failure(mock_logger, mock_ssh_client): @patch("subprocess.run") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_write_status_file_to_workers_failure(mock_logger, mock_run): """Test failed status file writing to workers with retry timeout.""" mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index 4cff07a0c0..bfd26001c4 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -15,38 +15,36 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import torchrun_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.drivers import torchrun_driver # noqa: E402 -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} - -DUMMY_distributed = {"_type": "torchrun", "process_count_per_node": 2} +DUMMY_DISTRIBUTED = {"process_count_per_node": 2} @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): assert torchrun_driver.get_base_pytorch_command() == ["torchrun"] @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(1, 8), ) def test_get_base_pytorch_command_torch_distributed_launch( mock_pytorch_version, mock_get_python_executable @@ -62,38 +60,29 @@ def test_get_base_pytorch_command_torch_distributed_launch( "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", "SM_NETWORK_INTERFACE_NAME": "eth0", "SM_HOST_COUNT": "1", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -102,7 +91,7 @@ def test_create_commands_single_node( "torchrun", "--nnodes=1", "--nproc_per_node=2", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() @@ -118,38 +107,29 @@ def test_create_commands_single_node( "SM_MASTER_ADDR": "algo-1", "SM_MASTER_PORT": "7777", "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -161,7 +141,7 @@ def test_create_commands_multi_node( "--master_addr=algo-1", "--master_port=7777", "--node_rank=0", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py index aba97996b0..bc37a32552 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -13,7 +13,7 @@ """Container Utils Unit Tests.""" from __future__ import absolute_import -from sagemaker.modules.train.container_drivers.utils import ( +from sagemaker.modules.train.container_drivers.common.utils import ( safe_deserialize, safe_serialize, hyperparameters_to_cli_args, diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 29da03bcd9..aa9f25caa3 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -66,7 +66,7 @@ ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg -from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER +from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER from tests.unit import DATA_DIR DEFAULT_BASE_NAME = "dummy-image-job" @@ -411,7 +411,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": Torchrun(), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": {}, }, { @@ -424,7 +426,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ tensor_parallel_degree=5, ) ), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": { "mp_parameters": json.dumps( { @@ -441,7 +445,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ "distributed": MPI( mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], ), - "expected_template": EXECUTE_MPI_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="MPI", driver_script="mpi_driver.py" + ), "expected_hyperparameters": {}, }, ], From 78175e4b39dcb3882b1504c2a8567e792b7af766 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 17:13:51 -0800 Subject: [PATCH 02/16] pylint --- src/sagemaker/modules/distributed.py | 2 -- src/sagemaker/modules/train/container_drivers/__init__.py | 2 +- .../modules/train/container_drivers/common/__init__.py | 2 +- .../modules/train/container_drivers/drivers/__init__.py | 2 +- .../train/container_drivers/drivers/basic_script_driver.py | 2 +- .../modules/train/container_drivers/drivers/mpi_driver.py | 2 +- .../modules/train/container_drivers/drivers/mpi_utils.py | 5 +++-- .../train/container_drivers/drivers/torchrun_driver.py | 2 +- .../modules/train/container_drivers/scripts/environment.py | 2 +- 9 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 580ba16af9..adda80c1be 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -95,7 +95,6 @@ def driver_dir(self) -> str: Returns: str: Path to directory containing the driver script """ - pass @property @abstractmethod @@ -108,7 +107,6 @@ def driver_script(self) -> str: Returns: str: Name of the driver script file """ - pass class Torchrun(DistributedConfig): diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py index f59a8d25f2..864f3663b8 100644 --- a/src/sagemaker/modules/train/container_drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -11,4 +11,4 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Sagemaker modules container drivers directory.""" -from __future__ import absolute_import \ No newline at end of file +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py index 64e4ba0091..aab88c6b97 100644 --- a/src/sagemaker/modules/train/container_drivers/common/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -11,4 +11,4 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Sagemaker modules container drivers - common directory.""" -from __future__ import absolute_import \ No newline at end of file +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py index 68c15efa09..a44e7e81a9 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py @@ -11,4 +11,4 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Sagemaker modules container drivers - drivers directory.""" -from __future__ import absolute_import \ No newline at end of file +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py index f39563ed92..0b086a8e4f 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py @@ -23,7 +23,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, get_python_executable, execute_commands, diff --git a/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py index c71737b277..4f5081c670 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py @@ -29,7 +29,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, hyperparameters_to_cli_args, get_process_count, diff --git a/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py index d19bc2d78d..ec9e1fcef9 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py @@ -17,14 +17,15 @@ import sys import subprocess import time -import paramiko from pathlib import Path from typing import List +import paramiko + sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, diff --git a/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py index 30410a538a..ed77e17235 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py @@ -22,7 +22,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, hyperparameters_to_cli_args, get_process_count, diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index 3e405c4289..0ce24c55d8 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -24,7 +24,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) -from common.utils import ( # noqa: E402 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 safe_serialize, safe_deserialize, read_distributed_json, From 6a6c5417d5ec509e44e90a6e50f17d8dd15d34a0 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 18:51:38 -0800 Subject: [PATCH 03/16] Include none types when creating config jsons for safer reference --- src/sagemaker/modules/distributed.py | 20 ++++++++++++ .../container_drivers/drivers/mpi_driver.py | 4 +-- .../drivers/torchrun_driver.py | 2 +- src/sagemaker/modules/train/model_trainer.py | 4 +-- .../scripts/test_enviornment.py | 31 ++++++++++++++++++- .../train/container_drivers/test_utils.py | 17 ++++++++++ 6 files changed, 72 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index adda80c1be..02933f6dc1 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -128,10 +128,20 @@ class Torchrun(DistributedConfig): @property def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") @property def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script file + """ return "torchrun_driver.py" @@ -154,8 +164,18 @@ class MPI(DistributedConfig): @property def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") @property def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script + """ return "mpi_driver.py" diff --git a/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py index 4f5081c670..9946272617 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py @@ -77,7 +77,7 @@ def main(): host_list = json.loads(os.environ["SM_HOSTS"]) host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = int(distributed_config.get("process_count_per_node", 0)) + process_count = int(distributed_config["process_count_per_node"] or 0) process_count = get_process_count(process_count) if process_count > 1: @@ -87,7 +87,7 @@ def main(): host_count=host_count, host_list=host_list, num_processes=process_count, - additional_options=distributed_config.get("mpi_additional_options", []), + additional_options=distributed_config["mpi_additional_options"] or [], entry_script_path=entry_script, ) diff --git a/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py index ed77e17235..7fcfabe05d 100644 --- a/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py @@ -69,7 +69,7 @@ def create_commands(): distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) hyperparameters = json.loads(os.environ["SM_HPS"]) - process_count = int(distributed_config.get("process_count_per_node", 0)) + process_count = int(distributed_config["process_count_per_node"] or 0) process_count = get_process_count(process_count) host_count = int(os.environ["SM_HOST_COUNT"]) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index ab518be4bc..284dd41bf5 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -776,7 +776,7 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour """Write the source code configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) with open(file_path, "w") as f: - dump = source_code.model_dump(exclude_none=True) if source_code else {} + dump = source_code.model_dump() if source_code else {} f.write(json.dumps(dump)) def _write_distributed_json( @@ -787,7 +787,7 @@ def _write_distributed_json( """Write the distributed runner configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) with open(file_path, "w") as f: - dump = distributed.model_dump(exclude_none=True) if distributed else {} + dump = distributed.model_dump() if distributed else {} f.write(json.dumps(dump)) def _prepare_train_script( diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index 1024f60646..a3f54ad439 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -73,6 +73,15 @@ }, } +SOURCE_CODE = { + "source_dir": "code", + "entry_script": "train.py", +} + +DISTRIBUTED_CONFIG = { + "process_count_per_node": 2, +} + OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") # flake8: noqa @@ -87,6 +96,10 @@ export SM_LOG_LEVEL='20' export SM_MASTER_ADDR='algo-1' export SM_MASTER_PORT='7777' +export SM_SOURCE_DIR='/opt/ml/input/data/code' +export SM_ENTRY_SCRIPT='train.py' +export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers' +export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' export SM_CHANNELS='["train", "validation"]' @@ -110,6 +123,14 @@ """ +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json", + return_value=SOURCE_CODE, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json", + return_value=DISTRIBUTED_CONFIG, +) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) @@ -122,7 +143,13 @@ side_effect=safe_deserialize, ) def test_set_env( - mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons + mock_safe_deserialize, + mock_safe_serialize, + mock_num_neurons, + mock_num_gpus, + mock_num_cpus, + mock_read_distributed_json, + mock_read_source_code_json, ): with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): set_env( @@ -135,6 +162,8 @@ def test_set_env( mock_num_cpus.assert_called_once() mock_num_gpus.assert_called_once() mock_num_neurons.assert_called_once() + mock_read_distributed_json.assert_called_once() + mock_read_source_code_json.assert_called_once() with open(OUTPUT_FILE, "r") as f: env_file = f.read().strip() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py index bc37a32552..beff06e8d8 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. """Container Utils Unit Tests.""" from __future__ import absolute_import +import os from sagemaker.modules.train.container_drivers.common.utils import ( safe_deserialize, safe_serialize, hyperparameters_to_cli_args, + get_process_count, ) SM_HPS = { @@ -119,3 +121,18 @@ def test_safe_serialize_empty_data(): assert safe_serialize("") == "" assert safe_serialize([]) == "[]" assert safe_serialize({}) == "{}" + + +def test_get_process_count(): + assert get_process_count() == 1 + assert get_process_count(2) == 2 + os.environ["SM_NUM_GPUS"] = "4" + assert get_process_count() == 4 + os.environ["SM_NUM_GPUS"] = "0" + os.environ["SM_NUM_NEURONS"] = "8" + assert get_process_count() == 8 + os.environ["SM_NUM_NEURONS"] = "0" + assert get_process_count() == 1 + del os.environ["SM_NUM_GPUS"] + del os.environ["SM_NUM_NEURONS"] + assert get_process_count() == 1 From b3ae566b4e3bfa486c6df404ddc833732121d69e Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 19:42:02 -0800 Subject: [PATCH 04/16] fix: update test to account for changes --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index aa9f25caa3..1cf9c9c863 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -504,19 +504,19 @@ def test_train_with_distributed_config( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed"].model_dump(exclude_none=True) == ( + assert test_case["distributed"].model_dump() == ( json.loads(runner_json_content) ) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( + assert test_case["source_code"].model_dump() == ( json.loads(source_code_json_content) ) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( + assert test_case["source_code"].model_dump() == ( json.loads(source_code_json_content) ) finally: From a62e862454e0c50e029f019734cc8d6a2334cb49 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 13 Feb 2025 19:47:22 -0800 Subject: [PATCH 05/16] format --- .../sagemaker/modules/train/test_model_trainer.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 1cf9c9c863..9d17a0db00 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -504,21 +504,15 @@ def test_train_with_distributed_config( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed"].model_dump() == ( - json.loads(runner_json_content) - ) + assert test_case["distributed"].model_dump() == (json.loads(runner_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump() == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump() == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) finally: shutil.rmtree(tmp_dir.name) assert not os.path.exists(tmp_dir.name) From 757e258eff38082f86237af5c1bea9549a0c59bc Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 14 Feb 2025 13:56:35 -0800 Subject: [PATCH 06/16] Add integ test --- tests/data/modules/custom_drivers/driver.py | 34 +++++++++++++++++++ tests/data/modules/scripts/entry_script.py | 19 +++++++++++ .../modules/train/test_model_trainer.py | 34 ++++++++++++++++++- 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/data/modules/custom_drivers/driver.py create mode 100644 tests/data/modules/scripts/entry_script.py diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py new file mode 100644 index 0000000000..e2a1fc7a52 --- /dev/null +++ b/tests/data/modules/custom_drivers/driver.py @@ -0,0 +1,34 @@ +import json +import os +import subprocess +import sys + + +def main(): + driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + process_count_per_node = driver_config["process_count_per_node"] + assert process_count_per_node != None + + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + assert isinstance(hps, dict) + + source_dir = os.environ["SM_SOURCE_DIR"] + assert source_dir == "/opt/ml/input/data/code" + sm_drivers_dir = os.environ["SM_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/drivers" + + entry_script = os.environ["SM_ENTRY_SCRIPT"] + assert entry_script != None + + python = sys.executable + + command = [python, entry_script] + print(f"Running command: {command}") + subprocess.run(command, check=True) + + +if __name__ == "__main__": + print("Running custom driver script") + main() + print("Finished running custom driver script") diff --git a/tests/data/modules/scripts/entry_script.py b/tests/data/modules/scripts/entry_script.py new file mode 100644 index 0000000000..3c972bd956 --- /dev/null +++ b/tests/data/modules/scripts/entry_script.py @@ -0,0 +1,19 @@ +import json +import os +import time + + +def main(): + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + print(f"Hyperparameters: {hps}") + + print("Running pseudo training script") + for epochs in range(hps["epochs"]): + print(f"Epoch: {epochs}") + time.sleep(1) + print("Finished running pseudo training script") + + +if __name__ == "__main__": + main() diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cd298402b2..cb5cfb10f1 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -17,7 +17,7 @@ from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import SourceCode, Compute -from sagemaker.modules.distributed import MPI, Torchrun +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -106,3 +106,35 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session): ) model_trainer.train() + + +def test_custom_distributed_driver(modules_sagemaker_session): + class CustomDriver(DistributedConfig): + process_count_per_node: int = None + + @property + def driver_dir(self) -> str: + return f"{DATA_DIR}/modules/custom_drivers" + + @property + def driver_script(self) -> str: + return "driver.py" + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/scripts", + entry_script="entry_script.py", + ) + + hyperparameters = {"epochs": 10} + + custom_driver = CustomDriver(process_count_per_node=2) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=hyperparameters, + source_code=source_code, + distributed=custom_driver, + base_job_name="custom-distributed-driver", + ) + model_trainer.train() From fb4fa7e59eb3968998678a53b4e32051a575a428 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 25 Feb 2025 13:35:29 -0800 Subject: [PATCH 07/16] pylint --- src/sagemaker/modules/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 02933f6dc1..31465e79d0 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -17,9 +17,9 @@ from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List -from pydantic import BaseModel from sagemaker.modules.utils import safe_serialize from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH +from sagemaker.modules.configs import BaseConfig class SMP(BaseConfig): From 3e23078089c309e57a2781a79fd1aff830b12d37 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 25 Feb 2025 22:49:56 +0000 Subject: [PATCH 08/16] prepare release v2.240.0 --- CHANGELOG.md | 21 +++++++++++++++++++++ VERSION | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 446b4db426..742e46d127 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## v2.240.0 (2025-02-25) + +### Features + + * Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK + +### Bug Fixes and Other Changes + + * Remove main function entrypoint in ModelBuilder dependency manager. + * forbid extras in Configs + * altconfig hubcontent and reenable integ test + * Merge branch 'master-rba' into local_merge + * py_version doc fixes + * Add backward compatbility for RecordSerializer and RecordDeserializer + * update image_uri_configs 02-21-2025 06:18:10 PST + * update image_uri_configs 02-20-2025 06:18:08 PST + +### Documentation Changes + + * Removed a line about python version requirements of training script which can misguide users. + ## v2.239.3 (2025-02-19) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index f61726ee77..d7ff33493f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.239.4.dev0 +2.240.0 From 0433e5c2b6fbbb16ae9d0a6999baa9b688672f40 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 25 Feb 2025 22:50:01 +0000 Subject: [PATCH 09/16] update development version to v2.240.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index d7ff33493f..1b1f3a78e8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.240.0 +2.240.1.dev0 From 4d1bceedee8e4c616f27b7fe566d480ac5333395 Mon Sep 17 00:00:00 2001 From: pintaoz-aws <167920275+pintaoz-aws@users.noreply.github.com> Date: Fri, 28 Feb 2025 12:17:41 -0800 Subject: [PATCH 10/16] Fix key error in _send_metrics() (#5068) Co-authored-by: pintaoz --- src/sagemaker/experiments/_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py index 31dd679cc8..026e73e8a6 100644 --- a/src/sagemaker/experiments/_metrics.py +++ b/src/sagemaker/experiments/_metrics.py @@ -197,8 +197,8 @@ def _send_metrics(self, metrics): response = self._metrics_client.batch_put_metrics(**request) errors = response["Errors"] if "Errors" in response else None if errors: - message = errors[0]["Message"] - raise Exception(f'{len(errors)} errors with message "{message}"') + error_code = errors[0]["Code"] + raise Exception(f'{len(errors)} errors with error code "{error_code}"') def _construct_batch_put_metrics_request(self, batch): """Creates dictionary object used as request to metrics service.""" From 7cdf93358096c1b846716f4a0eec99c3a37b93a9 Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Sat, 1 Mar 2025 03:32:18 +0530 Subject: [PATCH 11/16] fix: Added check for the presence of model package group before creating one (#5063) Co-authored-by: Keshav Chandak --- src/sagemaker/session.py | 56 +++++++++++++++++++++++++++--- tests/unit/test_session.py | 70 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c6a2014ae5..b2398e03d1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4347,11 +4347,59 @@ def submit(request): if model_package_group_name is not None and not model_package_group_name.startswith( "arn:" ): - _create_resource( - lambda: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] + is_model_package_group_present = False + try: + model_package_groups_response = self.search( + resource="ModelPackageGroup", + search_expression={ + "Filters": [ + { + "Name": "ModelPackageGroupName", + "Value": request["ModelPackageGroupName"], + "Operator": "Equals", + } + ], + }, + ) + if len(model_package_groups_response.get("Results")) > 0: + is_model_package_group_present = True + except Exception: # pylint: disable=W0703 + model_package_groups = [] + model_package_groups_response = self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + while next_token is not None and next_token != "": + model_package_groups_response = ( + self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], NextToken=next_token + ) + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + filtered_model_package_group = list( + filter( + lambda mpg: mpg.get("ModelPackageGroupName") + == request["ModelPackageGroupName"], + model_package_groups, + ) + ) + is_model_package_group_present = len(filtered_model_package_group) > 0 + if not is_model_package_group_present: + _create_resource( + lambda: self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) ) - ) if "SourceUri" in request and request["SourceUri"] is not None: # Remove inference spec from request if the # given source uri can lead to auto-population of it diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d2d2c3bcfb..f873e9b14c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5006,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5094,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + created_versioned_mp_arn = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) @@ -5149,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} with pytest.raises( ValueError, @@ -5221,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake return_value={"ModelPackageArn": created_versioned_mp_arn} ) + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( model_package_group_name=model_package_group_name, containers=containers, @@ -5443,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5510,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types( approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -7183,3 +7191,65 @@ def test_delete_hub_content_reference(sagemaker_session): } sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search( + sagemaker_session, +): + sagemaker_session.sagemaker_client.search.side_effect = Exception() + sagemaker_session.sagemaker_client.search.return_value = {} + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + { + "ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}], + "NextToken": "NextToken", + }, + {"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]}, + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + {"ModelPackageGroupSummaryList": []} + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session): + # with search api + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "ModelPackageGroup": { + "ModelPackageGroupName": "mock-mpg", + "ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg", + } + } + ] + } + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) From af7d7decc4c6c76958e85060ac6570e7406685b6 Mon Sep 17 00:00:00 2001 From: pintaoz-aws <167920275+pintaoz-aws@users.noreply.github.com> Date: Mon, 3 Mar 2025 10:42:28 -0800 Subject: [PATCH 12/16] Use sagemaker session's s3_resource in download_folder (#5064) Co-authored-by: pintaoz --- src/sagemaker/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index c575b1eeb6..1a75a3a5cc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -397,8 +397,7 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): sagemaker_session (sagemaker.session.Session): a sagemaker session to interact with S3. """ - boto_session = sagemaker_session.boto_session - s3 = boto_session.resource("s3", region_name=boto_session.region_name) + s3 = sagemaker_session.s3_resource prefix = prefix.lstrip("/") From e49095bcbcd3a4ab610393a23a292278280cf033 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 4 Mar 2025 18:26:25 -0800 Subject: [PATCH 13/16] remove union --- src/sagemaker/modules/train/model_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 284dd41bf5..09e14c650b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -152,7 +152,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed (Optional[Union[DistributedConfig]]): + distributed (Optional[DistributedConfig]): The distributed runner for the training job. This is used to configure a distributed training job. If specifed, ``source_code`` must also be provided. @@ -213,7 +213,7 @@ class ModelTrainer(BaseModel): role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed: Optional[Union[DistributedConfig]] = None + distributed: Optional[DistributedConfig] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None From c5b1e12f25dfe762057917001fe8c433f4b067fe Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 4 Mar 2025 18:31:14 -0800 Subject: [PATCH 14/16] fix merge artifact --- src/sagemaker/modules/distributed.py | 2 +- tests/integ/sagemaker/modules/train/test_model_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 31465e79d0..b7b2f12568 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -76,7 +76,7 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseModel, ABC): +class DistributedConfig(BaseConfig, ABC): """Abstract base class for distributed training configurations. This class defines the interface that all distributed training configurations diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index 2517040e69..a1e3106553 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -142,7 +142,6 @@ def driver_script(self) -> str: source_code = SourceCode( source_dir=f"{DATA_DIR}/modules/scripts", - requirements= entry_script="entry_script.py", ) @@ -158,3 +157,4 @@ def driver_script(self) -> str: distributed=custom_driver, base_job_name="custom-distributed-driver", ) + model_trainer.train() From b562d69452028b519363ba2df0fa5a328b5f95b8 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 4 Mar 2025 20:24:08 -0800 Subject: [PATCH 15/16] Change dir path to distributed_drivers --- src/sagemaker/modules/distributed.py | 4 +-- src/sagemaker/modules/templates.py | 4 +-- .../__init__.py | 0 .../basic_script_driver.py | 0 .../mpi_driver.py | 0 .../mpi_utils.py | 0 .../torchrun_driver.py | 0 .../container_drivers/scripts/environment.py | 4 +-- src/sagemaker/modules/train/model_trainer.py | 2 +- tests/data/modules/custom_drivers/driver.py | 4 +-- .../scripts/test_enviornment.py | 2 +- .../container_drivers/test_mpi_driver.py | 34 +++++++++---------- .../train/container_drivers/test_mpi_utils.py | 8 ++--- .../container_drivers/test_torchrun_driver.py | 28 ++++++++------- 14 files changed, 46 insertions(+), 44 deletions(-) rename src/sagemaker/modules/train/container_drivers/{drivers => distributed_drivers}/__init__.py (100%) rename src/sagemaker/modules/train/container_drivers/{drivers => distributed_drivers}/basic_script_driver.py (100%) rename src/sagemaker/modules/train/container_drivers/{drivers => distributed_drivers}/mpi_driver.py (100%) rename src/sagemaker/modules/train/container_drivers/{drivers => distributed_drivers}/mpi_utils.py (100%) rename src/sagemaker/modules/train/container_drivers/{drivers => distributed_drivers}/torchrun_driver.py (100%) diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index b7b2f12568..f248b9b77c 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -133,7 +133,7 @@ def driver_dir(self) -> str: Returns: str: Path to directory containing the driver script """ - return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") @property def driver_script(self) -> str: @@ -169,7 +169,7 @@ def driver_dir(self) -> str: Returns: str: Path to directory containing the driver script """ - return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") @property def driver_script(self) -> str: diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index 9dfef646ed..d888b7bcb9 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -21,12 +21,12 @@ EXECUTE_BASIC_SCRIPT_DRIVER = """ echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ EXEUCTE_DISTRIBUTED_DRIVER = """ echo "Running {driver_name} Driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script} +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script} """ TRAIN_SCRIPT_TEMPLATE = """ diff --git a/src/sagemaker/modules/train/container_drivers/drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py similarity index 100% rename from src/sagemaker/modules/train/container_drivers/drivers/__init__.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py diff --git a/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py similarity index 100% rename from src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py diff --git a/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py similarity index 100% rename from src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py diff --git a/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py similarity index 100% rename from src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py diff --git a/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py similarity index 100% rename from src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py rename to src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index 0ce24c55d8..897b1f8af4 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -48,7 +48,7 @@ SM_OUTPUT_FAILURE = "/opt/ml/output/failure" SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" -SM_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/drivers" +SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" SM_MASTER_ADDR = "algo-1" SM_MASTER_PORT = 7777 @@ -173,7 +173,7 @@ def set_env( distributed = read_distributed_json() if distributed: - env_vars["SM_DRIVER_DIR"] = SM_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH env_vars["SM_DISTRIBUTED_CONFIG"] = distributed # Data Channels diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 5d3c04050e..aef6e3312b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -569,7 +569,7 @@ def train( # If distributed is provided, overwrite code under /drivers if self.distributed: distributed_driver_dir = self.distributed.driver_dir - driver_dir = os.path.join(tmp_dir.name, "drivers") + driver_dir = os.path.join(tmp_dir.name, "distributed_drivers") shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py index e2a1fc7a52..3395b80da9 100644 --- a/tests/data/modules/custom_drivers/driver.py +++ b/tests/data/modules/custom_drivers/driver.py @@ -15,8 +15,8 @@ def main(): source_dir = os.environ["SM_SOURCE_DIR"] assert source_dir == "/opt/ml/input/data/code" - sm_drivers_dir = os.environ["SM_DRIVER_DIR"] - assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/drivers" + sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers" entry_script = os.environ["SM_ENTRY_SCRIPT"] assert entry_script != None diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index a3f54ad439..fe4fa08825 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -98,7 +98,7 @@ export SM_MASTER_PORT='7777' export SM_SOURCE_DIR='/opt/ml/input/data/code' export SM_ENTRY_SCRIPT='train.py' -export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers' +export SM_DISTRIBUTED_DRIVER_DIR='/opt/ml/input/data/sm_drivers/distributed_drivers' export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index a752360981..4eb7512d22 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -22,7 +22,7 @@ sys.modules["utils"] = MagicMock() sys.modules["mpi_utils"] = MagicMock() -from sagemaker.modules.train.container_drivers.drivers import mpi_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.distributed_drivers import mpi_driver # noqa: E402 DUMMY_MPI_COMMAND = [ @@ -63,13 +63,13 @@ "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, mock_get_mpirun_command, @@ -104,15 +104,15 @@ def test_mpi_driver_worker( "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_status_file_to_workers") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers") def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands, diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py index 6c9f2545f0..35208d708a 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -27,7 +27,7 @@ mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") with patch.dict("sys.modules", {"utils": mock_utils}): - from sagemaker.modules.train.container_drivers.drivers.mpi_utils import ( + from sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils import ( CustomHostKeyPolicy, _can_connect, write_status_file_to_workers, @@ -65,7 +65,7 @@ def test_custom_host_key_policy_invalid_hostname(): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") def test_can_connect_success(mock_logger, mock_ssh_client): """Test successful SSH connection.""" mock_client = Mock() @@ -81,7 +81,7 @@ def test_can_connect_success(mock_logger, mock_ssh_client): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") def test_can_connect_failure(mock_logger, mock_ssh_client): """Test SSH connection failure.""" mock_client = Mock() @@ -97,7 +97,7 @@ def test_can_connect_failure(mock_logger, mock_ssh_client): @patch("subprocess.run") -@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") def test_write_status_file_to_workers_failure(mock_logger, mock_run): """Test failed status file writing to workers with retry timeout.""" mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index bfd26001c4..2568346158 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -21,17 +21,19 @@ sys.modules["utils"] = MagicMock() -from sagemaker.modules.train.container_drivers.drivers import torchrun_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.distributed_drivers import ( # noqa: E402 + torchrun_driver, +) DUMMY_DISTRIBUTED = {"process_count_per_node": 2} @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", return_value=(2, 0), ) def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): @@ -39,11 +41,11 @@ def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", return_value=(1, 8), ) def test_get_base_pytorch_command_torch_distributed_launch( @@ -66,19 +68,19 @@ def test_get_base_pytorch_command_torch_distributed_launch( }, ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( @@ -113,19 +115,19 @@ def test_create_commands_single_node( }, ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( From 2c3c2e90cfa6add84890737f060246ba43c0b6f9 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 4 Mar 2025 23:06:11 -0800 Subject: [PATCH 16/16] update paths --- .../container_drivers/test_mpi_driver.py | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index 4eb7512d22..bf51db8285 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -63,12 +63,22 @@ "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, @@ -104,15 +114,27 @@ def test_mpi_driver_worker( "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" +) def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands,