diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py index 03146a3bbe..641809ebb8 100644 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py @@ -50,12 +50,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py index 2c20151ed1..4b6be357f0 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -75,12 +75,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py index 03146a3bbe..641809ebb8 100644 --- a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py +++ b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py @@ -50,12 +50,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py index afe0f80012..086510341b 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -75,12 +75,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py b/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py index 1cbfbcd872..11bd5aa127 100644 --- a/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py +++ b/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py @@ -17,6 +17,7 @@ import sys import json +import pytest from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() @@ -146,3 +147,28 @@ def test_create_commands_multi_node( command = torchrun_driver.create_commands() assert command == expected_command + + +@pytest.mark.parametrize("instance_type", ["ml.p5.48xlarge", "ml.p5e.48xlarge"]) +@patch.dict( + os.environ, + { + "SM_NETWORK_INTERFACE_NAME": "eth0", + "SM_HOST_COUNT": "2", + "SM_MASTER_ADDR": "algo-1", + "SM_MASTER_PORT": "7777", + "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", + }, +) +def test_p5_p5e_efa_environment_setup(instance_type): + """Test that P5 and P5e instances are in EFA instance lists.""" + from sagemaker.train.container_drivers.common.utils import ( + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + ) + + assert instance_type in SM_EFA_NCCL_INSTANCES + assert instance_type in SM_EFA_RDMA_INSTANCES