From 834ab00001010f5b44ea3062eb15c2dd054ea783 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Feb 2025 17:29:22 +0000 Subject: [PATCH 1/3] fix: keep sagemaker_session from being overridden to None, add unit/integ tests --- src/sagemaker/jumpstart/cache.py | 6 ++++-- src/sagemaker/jumpstart/hub/utils.py | 6 ++++++ .../model/test_jumpstart_private_hub_model.py | 17 ++++++++++++++++ .../sagemaker/jumpstart/hub/test_utils.py | 16 ++++++++++++++- tests/unit/sagemaker/jumpstart/test_cache.py | 20 +++++++++++++++++++ 5 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8ac813a6c5..9027a143de 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -110,7 +110,8 @@ def __init__( sagemaker_session: sagemaker session object to use. Default: session object from default region us-west-2. """ - + # if sagemaker_session is None: + # sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) @@ -150,7 +151,8 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) - self._sagemaker_session = sagemaker_session + # Fallback in case a caller overrides sagemaker_session to None + self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index edc3c08fa7..1bbc6198a2 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -78,6 +78,9 @@ def construct_hub_arn_from_name( account_id: Optional[str] = None, ) -> str: """Constructs a Hub arn from the Hub name using default Session values.""" + if session is None: + # session is overridden to none by some callers + session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION account_id = account_id or session.account_id() region = region or session.boto_region_name @@ -211,6 +214,9 @@ def get_hub_model_version( ClientError: If the specified model is not found in the hub. KeyError: If the specified model version is not found. """ + if sagemaker_session is None: + # sagemaker_session is overridden to none by some callers + sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: hub_content_summaries = sagemaker_session.list_hub_content_versions( diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index fa3e37f403..c378520196 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -82,6 +82,23 @@ def test_jumpstart_hub_model(setup, add_model_references): assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) +def test_jumpstart_hub_model_with_default_session(setup, add_model_references): + model_version = "*" + hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + + model_id = "catboost-classification-model" + + sagemaker_session = get_sm_session() + + model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) + + def test_jumpstart_hub_gated_model(setup, add_model_references): model_id = "meta-textgeneration-llama-3-2-1b" diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 6dbb1340f4..a41511f27d 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -14,7 +14,10 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION +) from sagemaker.jumpstart.hub import parser_utils, utils @@ -80,6 +83,17 @@ def test_construct_hub_arn_from_name(): ) +def test_construct_hub_arn_from_name_with_session_none(): + hub_name = "my-cool-hub" + account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id() + boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=None) + == f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}" + ) + + def test_construct_hub_model_arn_from_inputs(): model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index da20debc6a..f4d0814048 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -26,6 +26,7 @@ from sagemaker.jumpstart.cache import ( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JumpStartModelsCache, ) from sagemaker.jumpstart.constants import ( @@ -53,6 +54,25 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +@patch("boto3.client") +def test_jumpstart_cache_init(mock_boto3_client): + cache = JumpStartModelsCache() + assert cache._region == "dummy-region" + assert cache.s3_bucket_name == "dummy-bucket" + assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region") + + # Some callers override the session to None, should still be set to default + cache = JumpStartModelsCache(sagemaker_session=None) + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): From edd77d20c2c43d3cd0cfb5c035d1cebab89bdb57 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Feb 2025 17:31:57 +0000 Subject: [PATCH 2/3] remove commented code --- src/sagemaker/jumpstart/cache.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 9027a143de..86ac71a9ee 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -110,8 +110,7 @@ def __init__( sagemaker_session: sagemaker session object to use. Default: session object from default region us-west-2. """ - # if sagemaker_session is None: - # sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) From 1918a96743a21d5341b57c783401fabc850317ee Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Feb 2025 18:23:26 +0000 Subject: [PATCH 3/3] fix styling issues --- tests/unit/sagemaker/jumpstart/hub/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index a41511f27d..a0b824fc9b 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -15,8 +15,8 @@ from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import ( - JUMPSTART_DEFAULT_REGION_NAME, - DEFAULT_JUMPSTART_SAGEMAKER_SESSION + JUMPSTART_DEFAULT_REGION_NAME, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.hub import parser_utils, utils