diff --git a/src/sagemaker/hyperpod/cli/cluster_stack_utils.py b/src/sagemaker/hyperpod/cli/cluster_stack_utils.py index 5d3c7ad5..3d35eda7 100644 --- a/src/sagemaker/hyperpod/cli/cluster_stack_utils.py +++ b/src/sagemaker/hyperpod/cli/cluster_stack_utils.py @@ -11,7 +11,7 @@ All other functions are private implementation details and should not be used directly. """ -import boto3 +from sagemaker.hyperpod.common.utils import create_boto3_client import click import logging from typing import List, Dict, Any, Optional, Tuple, Callable @@ -53,7 +53,7 @@ def _get_stack_resources(stack_name: str, region: str, logger: Optional[logging. if logger: logger.debug(f"Fetching resources for stack '{stack_name}' in region '{region}'") - cf_client = boto3.client('cloudformation', region_name=region) + cf_client = create_boto3_client('cloudformation', region_name=region) try: resources_response = cf_client.list_stack_resources(StackName=stack_name) resources = resources_response.get('StackResourceSummaries', []) @@ -208,7 +208,7 @@ def _handle_partial_deletion_failure(stack_name: str, region: str, original_reso message_callback("✗ Stack deletion failed") try: - cf_client = boto3.client('cloudformation', region_name=region) + cf_client = create_boto3_client('cloudformation', region_name=region) current_resources_response = cf_client.list_stack_resources(StackName=stack_name) current_resources = current_resources_response.get('StackResourceSummaries', []) @@ -273,7 +273,7 @@ def _perform_stack_deletion(stack_name: str, region: str, retain_list: List[str] if retain_list: logger.debug(f"Retaining resources: {retain_list}") - cf_client = boto3.client('cloudformation', region_name=region) + cf_client = create_boto3_client('cloudformation', region_name=region) delete_params = {'StackName': stack_name} if retain_list: diff --git a/src/sagemaker/hyperpod/cli/commands/cluster.py b/src/sagemaker/hyperpod/cli/commands/cluster.py index 289a827a..0bdc111a 100644 --- a/src/sagemaker/hyperpod/cli/commands/cluster.py +++ b/src/sagemaker/hyperpod/cli/commands/cluster.py @@ -68,6 +68,7 @@ ) from sagemaker.hyperpod.common.utils import ( get_cluster_context as get_cluster_context_util, + _resolve_region, ) from sagemaker.hyperpod.observability.utils import ( get_monitoring_config, @@ -171,7 +172,8 @@ def list_cluster( user_agent_extra=get_user_agent_extra_suffix() ) - session = boto3.Session(region_name=region) if region else boto3.Session() + region = _resolve_region(region) + session = boto3.Session(region_name=region) if not validator.validate_aws_credential(session): logger.error("Failed to list clusters capacity due to invalid AWS credentials.") sys.exit(1) @@ -581,7 +583,8 @@ def timeout_handler(signum, frame): botocore_config = botocore.config.Config( user_agent_extra=get_user_agent_extra_suffix() ) - session = boto3.Session(region_name=region) if region else boto3.Session() + region = _resolve_region(region) + session = boto3.Session(region_name=region) if not validator.validate_aws_credential(session): logger.error("Cannot connect to HyperPod cluster due to aws credentials error") sys.exit(1) @@ -708,7 +711,8 @@ def describe_cluster(cluster_name: str, debug: bool, region: str) -> None: botocore_config = botocore.config.Config( user_agent_extra=get_user_agent_extra_suffix() ) - session = boto3.Session(region_name=region) if region else boto3.Session() + region = _resolve_region(region) + session = boto3.Session(region_name=region) sm_client = get_sagemaker_client(session, botocore_config) # Get cluster details using SageMaker client diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index 49c14fa0..400aa818 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -1,6 +1,6 @@ import click import json -import boto3 +from sagemaker.hyperpod.common.utils import create_boto3_client from typing import Optional from tabulate import tabulate @@ -89,7 +89,7 @@ def custom_invoke( except json.JSONDecodeError: raise click.ClickException("--body must be valid JSON") - rt = boto3.client("sagemaker-runtime") + rt = create_boto3_client("sagemaker-runtime") try: endpoint = Endpoint.get(endpoint_name) diff --git a/src/sagemaker/hyperpod/cli/service/get_logs.py b/src/sagemaker/hyperpod/cli/service/get_logs.py index 91ca739f..084f8fc8 100644 --- a/src/sagemaker/hyperpod/cli/service/get_logs.py +++ b/src/sagemaker/hyperpod/cli/service/get_logs.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from typing import Optional -import boto3 +from sagemaker.hyperpod.common.utils import create_boto3_client from sagemaker.hyperpod.cli.clients.kubernetes_client import ( KubernetesClient, @@ -129,7 +129,7 @@ def get_log_url(self, eks_cluster_name, region, node_name, pod_name, namespace, return console_prefix + log_group_prefix + log_stream def is_container_insights_addon_enabled(self, eks_cluster_name): - response = boto3.client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50) + response = create_boto3_client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50) if AMAZON_ClOUDWATCH_OBSERVABILITY in response.get('addons', []): return True else: diff --git a/src/sagemaker/hyperpod/common/cli_decorators.py b/src/sagemaker/hyperpod/common/cli_decorators.py index 50642684..6c7a8dc9 100644 --- a/src/sagemaker/hyperpod/common/cli_decorators.py +++ b/src/sagemaker/hyperpod/common/cli_decorators.py @@ -417,10 +417,10 @@ def _is_valid_jumpstart_model_id(model_id: str) -> bool: Uses same SageMaker API that's already being called during creation. """ try: - import boto3 from botocore.exceptions import ClientError + from sagemaker.hyperpod.common.utils import create_boto3_client - sagemaker_client = boto3.client('sagemaker') + sagemaker_client = create_boto3_client('sagemaker') # Use same API call that's failing in the current code sagemaker_client.describe_hub_content( diff --git a/src/sagemaker/hyperpod/common/utils.py b/src/sagemaker/hyperpod/common/utils.py index 43028892..95b34bea 100644 --- a/src/sagemaker/hyperpod/common/utils.py +++ b/src/sagemaker/hyperpod/common/utils.py @@ -130,7 +130,7 @@ def get_region_from_eks_arn(arn: str) -> str: def get_jumpstart_model_instance_types(model_id, region) -> List[str]: - client = boto3.client("sagemaker", region_name=region) + client = create_boto3_client("sagemaker", region_name=region) response = client.describe_hub_content( HubName="SageMakerPublicHub", HubContentType="Model", HubContentName=model_id @@ -145,7 +145,7 @@ def get_jumpstart_model_instance_types(model_id, region) -> List[str]: def get_cluster_instance_types(cluster, region) -> set: instance_types = set({}) - sagemaker_client = boto3.client("sagemaker", region_name=region) + sagemaker_client = create_boto3_client("sagemaker", region_name=region) response = sagemaker_client.describe_cluster(ClusterName=cluster) for instance_group in response["InstanceGroups"]: @@ -278,7 +278,7 @@ def set_cluster_context( logger = logging.getLogger(__name__) logger = setup_logging(logger) - client = boto3.client("sagemaker", region_name=region) + client = create_boto3_client("sagemaker", region_name=region) if not is_eks_orchestrator(client, cluster_name): raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.") @@ -309,7 +309,7 @@ def get_cluster_context(): def list_clusters( region: Optional[str] = None, ): - client = boto3.client("sagemaker", region_name=region) + client = create_boto3_client("sagemaker", region_name=region) clusters = client.list_clusters() eks_clusters = [] @@ -330,7 +330,7 @@ def get_current_cluster(): region = get_region_from_eks_arn(current_context) hyperpod_clusters = list_clusters(region)["Eks"] - client = boto3.client("sagemaker", region_name=region) + client = create_boto3_client("sagemaker", region_name=region) for cluster_name in hyperpod_clusters: if not is_eks_orchestrator(client, cluster_name): @@ -356,18 +356,42 @@ def get_current_region(): except: return get_aws_default_region() +def _resolve_region(region_name: Optional[str] = None) -> Optional[str]: + """Resolve AWS region using the following fallback order: + 1. Explicit region_name parameter (from --region flag) + 2. AWS_REGION env var + 3. AWS_DEFAULT_REGION / ~/.aws/config (standard boto3 chain) + 4. Region from current cluster context (last resort) + """ + if region_name: + return region_name + + aws_region_env = os.environ.get('AWS_REGION') + if aws_region_env: + return aws_region_env + + boto3_region = boto3.session.Session().region_name + if boto3_region: + return boto3_region + + try: + return get_region_from_eks_arn(get_cluster_context()) + except Exception: + return None + def create_boto3_client(service_name: str, region_name: Optional[str] = None, **kwargs): """Create a boto3 client with smart region handling. Args: service_name (str): AWS service name (e.g., 'sagemaker', 'eks') - region_name (Optional[str]): AWS region. If None, uses AWS default + region_name (Optional[str]): AWS region. If None, resolved via + AWS_REGION env var, boto3 defaults, or cluster context. **kwargs: Additional boto3 client parameters Returns: boto3 client instance """ - return boto3.client(service_name, region_name=region_name or boto3.session.Session().region_name, **kwargs) + return boto3.client(service_name, region_name=_resolve_region(region_name), **kwargs) def region_to_az_ids(region_code: str): """ diff --git a/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py b/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py index 2547d57a..ec103c27 100644 --- a/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py +++ b/src/sagemaker/hyperpod/inference/jumpstart_public_hub_visualization_utils.py @@ -15,7 +15,7 @@ from __future__ import absolute_import import time -import boto3 +from sagemaker.hyperpod.common.utils import create_boto3_client import itables import pandas import logging @@ -32,8 +32,8 @@ class ModelDataLoader: MAX_RESULTS_PER_CALL = 100 def __init__(self, region: str, hub_name: str = "SageMakerPublicHub"): - config = Config(region_name=region, retries={"max_attempts": 10, "mode": "adaptive"}) - self.client = boto3.client("sagemaker", config=config) + config = Config(retries={"max_attempts": 10, "mode": "adaptive"}) + self.client = create_boto3_client("sagemaker", region_name=region, config=config) self.hub_name = hub_name self.all_data = [] self.next_token = None diff --git a/src/sagemaker/hyperpod/observability/utils.py b/src/sagemaker/hyperpod/observability/utils.py index 7bb31a55..48072a18 100644 --- a/src/sagemaker/hyperpod/observability/utils.py +++ b/src/sagemaker/hyperpod/observability/utils.py @@ -1,16 +1,16 @@ import re from typing import Optional -import boto3 import yaml +from sagemaker.hyperpod.common.utils import create_boto3_client from sagemaker.hyperpod.observability.constants import AMAZON_HYPERPOD_OBSERVABILITY, GRAFANA_DASHBOARD_UID from sagemaker.hyperpod.observability.MonitoringConfig import MonitoringConfig # ToDO : move below functions to SDK util method instead of importing from CLI from sagemaker.hyperpod.cli.utils import get_eks_cluster_name, get_hyperpod_cluster_region def is_observability_addon_enabled(eks_cluster_name): - response = boto3.client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50) + response = create_boto3_client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50) if AMAZON_HYPERPOD_OBSERVABILITY in response.get('addons', []): return True else: @@ -41,7 +41,7 @@ def get_monitoring_config() -> Optional[MonitoringConfig]: eks_cluster_name = get_eks_cluster_name() if not is_observability_addon_enabled(eks_cluster_name): return None - response = boto3.client("eks").describe_addon(clusterName=eks_cluster_name, addonName=AMAZON_HYPERPOD_OBSERVABILITY) + response = create_boto3_client("eks").describe_addon(clusterName=eks_cluster_name, addonName=AMAZON_HYPERPOD_OBSERVABILITY) config_values = yaml.safe_load(response['addon']['configurationValues']) try: diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index 5ccadbb7..cba16243 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -1,6 +1,7 @@ import logging import yaml import boto3 +from sagemaker.hyperpod.common.utils import create_boto3_client from typing import List, Optional, ClassVar, Dict, Set, Any from pydantic import BaseModel, Field, ConfigDict, model_validator from kubernetes import client, config @@ -429,7 +430,7 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: namespace = get_default_namespace() # Get caller identity - sts_client = boto3.client('sts') + sts_client = create_boto3_client('sts') caller_identity = sts_client.get_caller_identity() caller_arn = caller_identity['Arn'] diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index a85c1c00..177dedd1 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -387,15 +387,15 @@ def test_custom_create_missing_required_args(): @patch("sagemaker.hyperpod.cli.commands.inference.Endpoint.get") -@patch("sagemaker.hyperpod.cli.commands.inference.boto3") -def test_custom_invoke_success(mock_boto3, mock_endpoint_get): +@patch("sagemaker.hyperpod.cli.commands.inference.create_boto3_client") +def test_custom_invoke_success(mock_create_client, mock_endpoint_get): mock_endpoint = Mock() mock_endpoint.endpoint_status = "InService" mock_endpoint_get.return_value = mock_endpoint mock_body = Mock() mock_body.read.return_value.decode.return_value = '{"ok": true}' - mock_boto3.client.return_value.invoke_endpoint.return_value = {"Body": mock_body} + mock_create_client.return_value.invoke_endpoint.return_value = {"Body": mock_body} runner = CliRunner() result = runner.invoke( @@ -406,8 +406,8 @@ def test_custom_invoke_success(mock_boto3, mock_endpoint_get): assert '"ok": true' in result.output -@patch("sagemaker.hyperpod.cli.commands.inference.boto3") -def test_custom_invoke_invalid_json(mock_boto3): +@patch("sagemaker.hyperpod.cli.commands.inference.create_boto3_client") +def test_custom_invoke_invalid_json(mock_create_client): runner = CliRunner() result = runner.invoke(custom_invoke, ["--endpoint-name", "ep", "--body", "bad"]) assert result.exit_code != 0 diff --git a/test/unit_tests/cluster_management/test_hp_cluster_stack.py b/test/unit_tests/cluster_management/test_hp_cluster_stack.py index 9acc75b4..84912011 100644 --- a/test/unit_tests/cluster_management/test_hp_cluster_stack.py +++ b/test/unit_tests/cluster_management/test_hp_cluster_stack.py @@ -60,21 +60,17 @@ def mock_client_factory(service_name, **kwargs): # Verify create_stack was called self.assertTrue(mock_cf_client.create_stack.called) - @patch('boto3.session.Session') - @patch('boto3.client') - def test_describe_success(self, mock_boto3_client, mock_boto3_session): - mock_region = "us-west-2" - mock_boto3_session.return_value.region_name = mock_region - + @patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.create_boto3_client') + def test_describe_success(self, mock_create_client): mock_cf_client = MagicMock() - mock_boto3_client.return_value = mock_cf_client + mock_create_client.return_value = mock_cf_client mock_response = {'Stacks': [{'StackName': 'test-stack', 'StackStatus': 'CREATE_COMPLETE'}]} mock_cf_client.describe_stacks.return_value = mock_response result = HpClusterStack.describe('test-stack') - mock_boto3_client.assert_called_once_with('cloudformation', region_name=mock_region) + mock_create_client.assert_called_once_with('cloudformation', region_name=None) mock_cf_client.describe_stacks.assert_called_once_with(StackName='test-stack') self.assertEqual(result, mock_response) @@ -94,21 +90,17 @@ def test_describe_access_denied(self, mock_boto3_client, mock_boto3_session): with self.assertRaises(ValueError): HpClusterStack.describe('test-stack') - @patch('boto3.session.Session') - @patch('boto3.client') - def test_list_success(self, mock_boto3_client, mock_boto3_session): - mock_region = "us-west-2" - mock_boto3_session.return_value.region_name = mock_region - + @patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.create_boto3_client') + def test_list_success(self, mock_create_client): mock_cf_client = MagicMock() - mock_boto3_client.return_value = mock_cf_client + mock_create_client.return_value = mock_cf_client mock_response = {'StackSummaries': [{'StackName': 'stack1'}, {'StackName': 'stack2'}]} mock_cf_client.list_stacks.return_value = mock_response result = HpClusterStack.list() - mock_boto3_client.assert_called_once_with('cloudformation', region_name=mock_region) + mock_create_client.assert_called_once_with('cloudformation', region_name=None) mock_cf_client.list_stacks.assert_called_once() self.assertEqual(result, mock_response) diff --git a/test/unit_tests/common/test_utils.py b/test/unit_tests/common/test_utils.py index f43e37ff..033db4d8 100644 --- a/test/unit_tests/common/test_utils.py +++ b/test/unit_tests/common/test_utils.py @@ -14,6 +14,7 @@ get_cluster_context, parse_client_kubernetes_version, is_kubernetes_version_compatible, + _resolve_region, ) from kubernetes.client.exceptions import ApiException from pydantic import ValidationError @@ -442,4 +443,40 @@ def test_get_cluster_context_success(self, mock_list_contexts): result = get_cluster_context() self.assertEqual(result, "arn:aws:eks:us-west-2:123456789012:cluster/my-cluster") - mock_list_contexts.assert_called_once() \ No newline at end of file + mock_list_contexts.assert_called_once() + + +class TestResolveRegion(unittest.TestCase): + """Test the _resolve_region function""" + + def test_explicit_region_takes_precedence(self): + with patch.dict('os.environ', {'AWS_REGION': 'us-east-1'}): + assert _resolve_region('eu-west-1') == 'eu-west-1' + + @patch.dict('os.environ', {'AWS_REGION': 'us-west-2'}, clear=False) + @patch('sagemaker.hyperpod.common.utils.boto3.session.Session') + def test_aws_region_env_var(self, mock_session): + assert _resolve_region() == 'us-west-2' + mock_session.assert_not_called() + + @patch.dict('os.environ', {}, clear=True) + @patch('sagemaker.hyperpod.common.utils.boto3.session.Session') + def test_boto3_default_region_fallback(self, mock_session): + mock_session.return_value.region_name = 'ap-southeast-1' + assert _resolve_region() == 'ap-southeast-1' + + @patch('sagemaker.hyperpod.common.utils.get_cluster_context') + @patch.dict('os.environ', {}, clear=True) + @patch('sagemaker.hyperpod.common.utils.boto3.session.Session') + def test_cluster_context_fallback(self, mock_session, mock_context): + mock_session.return_value.region_name = None + mock_context.return_value = 'arn:aws:eks:us-west-2:123456789012:cluster/my-cluster' + assert _resolve_region() == 'us-west-2' + + @patch('sagemaker.hyperpod.common.utils.get_cluster_context') + @patch.dict('os.environ', {}, clear=True) + @patch('sagemaker.hyperpod.common.utils.boto3.session.Session') + def test_returns_none_when_nothing_configured(self, mock_session, mock_context): + mock_session.return_value.region_name = None + mock_context.side_effect = Exception("no context") + assert _resolve_region() is None