From a2d4de442d7f860579ceeb7743043c3569f72c09 Mon Sep 17 00:00:00 2001 From: Srujith Poondla Date: Fri, 16 Jan 2026 12:45:33 -0800 Subject: [PATCH] fix: Add ml.p5e.48xlarge and ml.p5.48xlarge to EFA instance lists Add ml.p5e.48xlarge to SM_EFA_NCCL_INSTANCES and SM_EFA_RDMA_INSTANCES. Add ml.p5.48xlarge to SM_EFA_RDMA_INSTANCES (was missing). Without these entries, NCCL hangs during distributed training initialization on P5e instances due to missing EFA environment variables (FI_PROVIDER, FI_EFA_USE_DEVICE_RDMA, RDMAV_FORK_SAFE). Fixes #5491 --- .../train/container_drivers/common/utils.py | 3 +++ .../bootstrap_runtime_environment.py | 3 +++ .../train/container_drivers/common/utils.py | 3 +++ .../bootstrap_runtime_environment.py | 3 +++ .../container_drivers/test_torchrun_driver.py | 26 +++++++++++++++++++ 5 files changed, 38 insertions(+) diff --git a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py index 03146a3bbe..641809ebb8 100644 --- a/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py +++ b/sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py @@ -50,12 +50,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py index 2c20151ed1..4b6be357f0 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -75,12 +75,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py index 03146a3bbe..641809ebb8 100644 --- a/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py +++ b/sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py @@ -50,12 +50,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py index afe0f80012..086510341b 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -75,12 +75,15 @@ "ml.p4d.24xlarge", "ml.p4de.24xlarge", "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] SM_EFA_RDMA_INSTANCES = [ "ml.p4d.24xlarge", "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.p5e.48xlarge", "ml.trn1.32xlarge", ] diff --git a/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py b/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py index 1cbfbcd872..11bd5aa127 100644 --- a/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py +++ b/sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py @@ -17,6 +17,7 @@ import sys import json +import pytest from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() @@ -146,3 +147,28 @@ def test_create_commands_multi_node( command = torchrun_driver.create_commands() assert command == expected_command + + +@pytest.mark.parametrize("instance_type", ["ml.p5.48xlarge", "ml.p5e.48xlarge"]) +@patch.dict( + os.environ, + { + "SM_NETWORK_INTERFACE_NAME": "eth0", + "SM_HOST_COUNT": "2", + "SM_MASTER_ADDR": "algo-1", + "SM_MASTER_PORT": "7777", + "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", + }, +) +def test_p5_p5e_efa_environment_setup(instance_type): + """Test that P5 and P5e instances are in EFA instance lists.""" + from sagemaker.train.container_drivers.common.utils import ( + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + ) + + assert instance_type in SM_EFA_NCCL_INSTANCES + assert instance_type in SM_EFA_RDMA_INSTANCES