From 0aeca78356617d03fdf513b90a3ad6903e91be74 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Wed, 5 Mar 2025 21:41:39 +0000 Subject: [PATCH 01/13] feat: add integ tests for training JumpStart models in private hub --- src/sagemaker/jumpstart/factory/estimator.py | 16 +- src/sagemaker/jumpstart/hub/parsers.py | 6 + src/sagemaker/jumpstart/types.py | 4 + tests/integ/sagemaker/jumpstart/constants.py | 2 +- .../private_hub/estimator/__init__.py | 0 .../test_jumpstart_private_hub_estimator.py | 220 ++++++++++++++++++ 6 files changed, 242 insertions(+), 6 deletions(-) create mode 100644 tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py create mode 100644 tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 17ad7a76f5..cf547f9f4a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -312,17 +312,23 @@ def _add_hub_access_config_to_kwargs_inputs( kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None ): """Adds HubAccessConfig to kwargs inputs""" - + dataset_uri = kwargs.specs.default_training_dataset_uri if isinstance(kwargs.inputs, str): - kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) + if dataset_uri is not None and dataset_uri == kwargs.inputs: + kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) elif isinstance(kwargs.inputs, TrainingInput): - kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) + if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]: + kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) elif isinstance(kwargs.inputs, dict): for k, v in kwargs.inputs.items(): if isinstance(v, str): - kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config) + training_input = TrainingInput(s3_data=v) + if dataset_uri is not None and dataset_uri == v: + training_input.add_hub_access_config(hub_access_config=hub_access_config) + kwargs.inputs[k] = training_input elif isinstance(kwargs.inputs, TrainingInput): - kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) + if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]: + kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) return kwargs diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 01b6c5fe87..8070b54e87 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response( specs["training_instance_type_variants"] = ( hub_model_document.training_instance_type_variants ) + if hub_model_document.default_training_dataset_uri: + _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.default_training_dataset_uri + ) + specs["default_training_dataset_key"] = default_training_dataset_key + specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 349396205e..5286ad31c2 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1279,6 +1279,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "hosting_neuron_model_version", "hub_content_type", "_is_hub_content", + "default_training_dataset_key", + "default_training_dataset_uri", ] _non_serializable_slots = ["_is_hub_content"] @@ -1462,6 +1464,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) self.model_subscription_link = json_obj.get("model_subscription_link") + self.default_training_dataset_key: Optional[str] = json_obj.get("default_training_dataset_key") + self.default_training_dataset_uri: Optional[str] = json_obj.get("default_training_dataset_uri") def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index 1ffb1d8dc0..740d88e9c0 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -47,7 +47,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"), - ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"), + ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"), ("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"), ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py new file mode 100644 index 0000000000..e3bd52ca61 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -0,0 +1,220 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time + +import pytest +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub + +from sagemaker.jumpstart.estimator import JumpStartEstimator +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, +) +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_sm_session, + get_training_dataset_for_model_and_version +) + +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +MAX_INIT_TIME_SECONDS = 5 + +TEST_MODEL_IDS = { + "huggingface-spc-bert-base-cased", + "meta-textgeneration-llama-2-7b", + "catboost-regression-model", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + hub_instance.create_model_reference(model_arn=model_arn) + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) + + +def test_jumpstart_hub_estimator(setup, add_model_references): + + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + sagemaker_session = get_sm_session() + + estimator = JumpStartEstimator( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + inputs = { + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references): + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + sagemaker_session = get_sm_session() + + estimator = JumpStartEstimator( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + inputs = { + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn() + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + accept_eula=True, + inputs = { + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + with pytest.raises(Exception): + estimator.fit( + inputs = { + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + + +def test_instantiating_estimator(setup, add_model_references): + + model_id = "catboost-regression-model" + + start_time = time.perf_counter() + + JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + elapsed_time = time.perf_counter() - start_time + + assert elapsed_time <= MAX_INIT_TIME_SECONDS From d7175bb8e3825dec341fd912bb50ab454e59948b Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Wed, 5 Mar 2025 21:43:10 +0000 Subject: [PATCH 02/13] fixed formatting --- src/sagemaker/jumpstart/factory/estimator.py | 14 ++++++++++--- src/sagemaker/jumpstart/types.py | 8 ++++++-- .../test_jumpstart_private_hub_estimator.py | 20 +++++++++---------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index cf547f9f4a..c5b8570e27 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -315,9 +315,14 @@ def _add_hub_access_config_to_kwargs_inputs( dataset_uri = kwargs.specs.default_training_dataset_uri if isinstance(kwargs.inputs, str): if dataset_uri is not None and dataset_uri == kwargs.inputs: - kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config) + kwargs.inputs = TrainingInput( + s3_data=kwargs.inputs, hub_access_config=hub_access_config + ) elif isinstance(kwargs.inputs, TrainingInput): - if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]: + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config) elif isinstance(kwargs.inputs, dict): for k, v in kwargs.inputs.items(): @@ -327,7 +332,10 @@ def _add_hub_access_config_to_kwargs_inputs( training_input.add_hub_access_config(hub_access_config=hub_access_config) kwargs.inputs[k] = training_input elif isinstance(kwargs.inputs, TrainingInput): - if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]: + if ( + dataset_uri is not None + and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ): kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config) return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5286ad31c2..0cd4bcc902 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1464,8 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) self.model_subscription_link = json_obj.get("model_subscription_link") - self.default_training_dataset_key: Optional[str] = json_obj.get("default_training_dataset_key") - self.default_training_dataset_uri: Optional[str] = json_obj.get("default_training_dataset_uri") + self.default_training_dataset_key: Optional[str] = json_obj.get( + "default_training_dataset_key" + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "default_training_dataset_uri" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index e3bd52ca61..09a072b96e 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -36,7 +36,7 @@ ) from tests.integ.sagemaker.jumpstart.utils import ( get_sm_session, - get_training_dataset_for_model_and_version + get_training_dataset_for_model_and_version, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -81,13 +81,13 @@ def test_jumpstart_hub_estimator(setup, add_model_references): ) estimator.fit( - inputs = { + inputs={ "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" f"{get_training_dataset_for_model_and_version(model_id, model_version)}", } ) - # test that we can create a JumpStartEstimator from existing job with `attach` + # test that we can create a JumpStartEstimator from existing job with `attach` estimator = JumpStartEstimator.attach( training_job_name=estimator.latest_training_job.name, model_id=model_id, @@ -121,14 +121,13 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference ) estimator.fit( - inputs = { + inputs={ "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" f"{get_training_dataset_for_model_and_version(model_id, model_version)}", } ) - - # test that we can create a JumpStartEstimator from existing job with `attach` + # test that we can create a JumpStartEstimator from existing job with `attach` estimator = JumpStartEstimator.attach( training_job_name=estimator.latest_training_job.name, model_id=model_id, @@ -138,7 +137,7 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], - role=get_sm_session().get_caller_identity_arn() + role=get_sm_session().get_caller_identity_arn(), ) response = predictor.predict(["hello", "world"]) @@ -159,10 +158,10 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): estimator.fit( accept_eula=True, - inputs = { + inputs={ "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" f"{get_training_dataset_for_model_and_version(model_id, model_version)}", - } + }, ) estimator = JumpStartEstimator.attach( @@ -196,14 +195,13 @@ def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references) ) with pytest.raises(Exception): estimator.fit( - inputs = { + inputs={ "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" f"{get_training_dataset_for_model_and_version(model_id, model_version)}", } ) - def test_instantiating_estimator(setup, add_model_references): model_id = "catboost-regression-model" From 8215b32c0eb9bb5617aca185dc30dabb3a6301d1 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Wed, 5 Mar 2025 21:56:32 +0000 Subject: [PATCH 03/13] remove unused imports --- .../estimator/test_jumpstart_private_hub_estimator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index 09a072b96e..cb4582376d 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -30,10 +30,6 @@ get_sm_session, with_exponential_backoff, ) -from tests.integ.sagemaker.jumpstart.constants import ( - ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, - JUMPSTART_TAG, -) from tests.integ.sagemaker.jumpstart.utils import ( get_sm_session, get_training_dataset_for_model_and_version, From f8f0e14425c86c0e8afd02d30774f5c94350417a Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Wed, 5 Mar 2025 22:03:14 +0000 Subject: [PATCH 04/13] fix unused imports --- .../estimator/test_jumpstart_private_hub_estimator.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index cb4582376d..fa873bd2fc 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -20,6 +20,8 @@ from sagemaker.jumpstart.hub.hub import Hub from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, @@ -29,14 +31,9 @@ get_public_hub_model_arn, get_sm_session, with_exponential_backoff, -) -from tests.integ.sagemaker.jumpstart.utils import ( - get_sm_session, get_training_dataset_for_model_and_version, ) -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket - MAX_INIT_TIME_SECONDS = 5 TEST_MODEL_IDS = { From 622f706819138eaae57ff62fa182bdcf2bbffd86 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Thu, 6 Mar 2025 19:31:25 +0000 Subject: [PATCH 05/13] fix unit test failure and fix bug around versioning --- src/sagemaker/estimator.py | 1 + src/sagemaker/jumpstart/hub/interfaces.py | 2 +- src/sagemaker/jumpstart/hub/utils.py | 19 +++++++++++---- src/sagemaker/jumpstart/types.py | 5 ---- .../test_jumpstart_private_hub_estimator.py | 24 ++++++------------- tests/unit/sagemaker/jumpstart/constants.py | 2 ++ tests/unit/sagemaker/jumpstart/test_types.py | 1 + 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index fa40719c9f..c5d67a7be7 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2511,6 +2511,7 @@ def start_new(cls, estimator, inputs, experiment_config): train_args = cls._get_train_args(estimator, inputs, experiment_config) logger.debug("Train args after processing defaults: %s", train_args) + print("rohan debug: ", train_args) estimator.sagemaker_session.train(**train_args) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index fd38868dcc..d67ef96b3e 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("ValidationSupported") else None ) - self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( @@ -671,6 +670,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: + self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") self.training_model_package_artifact_uri: Optional[str] = json_obj.get( "TrainingModelPackageArtifactUri" ) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 1bbc6198a2..a9a2a0de97 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier +from packaging import version PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:" @@ -219,9 +220,7 @@ def get_hub_model_version( sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: - hub_content_summaries = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type - ).get("HubContentSummaries") + hub_content_summaries = _list_hub_content_versions_helper(hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type, sagemaker_session=sagemaker_session) except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") @@ -237,6 +236,18 @@ def get_hub_model_version( return marketplace_hub_content_version raise +def _list_hub_content_versions_helper(hub_name, hub_content_name, hub_content_type, sagemaker_session): + all_hub_content_summaries = [] + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type + ) + all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + while "NextToken" in list_hub_content_versions_response: + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type, next_token=list_hub_content_versions_response["NextToken"] + ) + all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + return all_hub_content_summaries def _get_hub_model_version_for_open_weight_version( hub_content_summaries: List[Any], hub_model_version: Optional[str] = None @@ -244,7 +255,7 @@ def _get_hub_model_version_for_open_weight_version( available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if hub_model_version == "*" or hub_model_version is None: - return str(max(available_model_versions)) + return str(max(version.parse(v) for v in available_model_versions)) try: spec = SpecifierSet(f"=={hub_model_version}") diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0cd4bcc902..5997059cc3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,11 +1940,6 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - # gated model never use training model artifact - if self.gated_bucket: - return False - - # otherwise, return true is a training model package is not set return len(self.training_model_package_artifact_uris or {}) == 0 def is_gated_model(self) -> bool: diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index fa873bd2fc..633230f373 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -60,17 +60,12 @@ def add_model_references(): def test_jumpstart_hub_estimator(setup, add_model_references): - model_id, model_version = "huggingface-spc-bert-base-cased", "*" - sagemaker_session = get_sm_session() - estimator = JumpStartEstimator( model_id=model_id, - role=sagemaker_session.get_caller_identity_arn(), - sagemaker_session=sagemaker_session, - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) estimator.fit( @@ -85,14 +80,11 @@ def test_jumpstart_hub_estimator(setup, add_model_references): training_job_name=estimator.latest_training_job.name, model_id=model_id, model_version=model_version, - sagemaker_session=get_sm_session(), ) # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), ) response = predictor.predict(["hello", "world"]) @@ -100,7 +92,8 @@ def test_jumpstart_hub_estimator(setup, add_model_references): assert response is not None -def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references): +def test_jumpstart_hub_estimator_with_session(setup, add_model_references): + model_id, model_version = "huggingface-spc-bert-base-cased", "*" sagemaker_session = get_sm_session() @@ -125,12 +118,14 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference training_job_name=estimator.latest_training_job.name, model_id=model_id, model_version=model_version, + sagemaker_session=get_sm_session(), ) # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), ) response = predictor.predict(["hello", "world"]) @@ -144,9 +139,8 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): estimator = JumpStartEstimator( model_id=model_id, - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) estimator.fit( @@ -161,14 +155,11 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): training_job_name=estimator.latest_training_job.name, model_id=model_id, model_version=model_version, - sagemaker_session=get_sm_session(), ) # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), ) response = predictor.predict(["hello", "world"]) @@ -182,9 +173,8 @@ def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references) estimator = JumpStartEstimator( model_id=model_id, - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) with pytest.raises(Exception): estimator.fit( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 4021599120..0c9065feb5 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -15553,6 +15553,8 @@ }, "inference_enable_network_isolation": True, "training_enable_network_isolation": True, + "default_training_dataset_uri": None, + "default_training_dataset_key": "training-datasets/tf_flowers/", "resource_name_base": "pt-ic-mobilenet-v2", "hosting_eula_key": None, "hosting_model_package_arns": {}, diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index acce8ef4f1..0b5ef63947 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -378,6 +378,7 @@ def test_jumpstart_model_specs(): specs1.training_script_key == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) + assert specs1.default_training_dataset_key == "training-datasets/tf_flowers/" assert specs1.hyperparameters == [ JumpStartHyperparameter( { From 62184c14ae1364a1240b914964d19243085c32ca Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Thu, 6 Mar 2025 19:34:55 +0000 Subject: [PATCH 06/13] fix formatting --- src/sagemaker/jumpstart/hub/interfaces.py | 4 +++- src/sagemaker/jumpstart/hub/utils.py | 24 ++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index d67ef96b3e..6ba5a37c3c 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -670,7 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: - self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "DefaultTrainingDatasetUri" + ) self.training_model_package_artifact_uri: Optional[str] = json_obj.get( "TrainingModelPackageArtifactUri" ) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index a9a2a0de97..75af019ca6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -220,7 +220,12 @@ def get_hub_model_version( sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION try: - hub_content_summaries = _list_hub_content_versions_helper(hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type, sagemaker_session=sagemaker_session) + hub_content_summaries = _list_hub_content_versions_helper( + hub_name=hub_name, + hub_content_name=hub_model_name, + hub_content_type=hub_model_type, + sagemaker_session=sagemaker_session, + ) except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") @@ -236,7 +241,10 @@ def get_hub_model_version( return marketplace_hub_content_version raise -def _list_hub_content_versions_helper(hub_name, hub_content_name, hub_content_type, sagemaker_session): + +def _list_hub_content_versions_helper( + hub_name, hub_content_name, hub_content_type, sagemaker_session +): all_hub_content_summaries = [] list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type @@ -244,11 +252,17 @@ def _list_hub_content_versions_helper(hub_name, hub_content_name, hub_content_ty all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) while "NextToken" in list_hub_content_versions_response: list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type, next_token=list_hub_content_versions_response["NextToken"] - ) - all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + hub_name=hub_name, + hub_content_name=hub_content_name, + hub_content_type=hub_content_type, + next_token=list_hub_content_versions_response["NextToken"], + ) + all_hub_content_summaries.extend( + list_hub_content_versions_response.get("HubContentSummaries") + ) return all_hub_content_summaries + def _get_hub_model_version_for_open_weight_version( hub_content_summaries: List[Any], hub_model_version: Optional[str] = None ) -> str: From 562434678c7f4c81115c8a28ca4d0e2ecba97ce5 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Thu, 6 Mar 2025 20:28:00 +0000 Subject: [PATCH 07/13] fix unit tests --- tests/unit/sagemaker/jumpstart/estimator/test_estimator.py | 1 + tests/unit/sagemaker/jumpstart/test_types.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4a64b413f4..225b083cbe 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -688,6 +688,7 @@ def test_gated_model_non_model_package_s3_uri( instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + model_uri='s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key', source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", entry_point="transfer_learning.py", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 0b5ef63947..8697a673c6 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -332,9 +332,6 @@ def test_jumpstart_model_header(): def test_use_training_model_artifact(): specs1 = JumpStartModelSpecs(BASE_SPEC) assert specs1.use_training_model_artifact() - specs1.gated_bucket = True - assert not specs1.use_training_model_artifact() - specs1.gated_bucket = False specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} assert not specs1.use_training_model_artifact() From bd02b60f3cbc7cd17e9a560fc8b1960e9303f58b Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 16:32:54 +0000 Subject: [PATCH 08/13] fix model_uri usage issue --- src/sagemaker/estimator.py | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 10 ++++++++-- src/sagemaker/jumpstart/types.py | 5 +++++ .../test_jumpstart_private_hub_estimator.py | 16 ++++++++-------- .../jumpstart/estimator/test_estimator.py | 1 - tests/unit/sagemaker/jumpstart/test_types.py | 3 +++ 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c5d67a7be7..e16aa829cc 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2511,7 +2511,7 @@ def start_new(cls, estimator, inputs, experiment_config): train_args = cls._get_train_args(estimator, inputs, experiment_config) logger.debug("Train args after processing defaults: %s", train_args) - print("rohan debug: ", train_args) + estimator.sagemaker_session.train(**train_args) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index c5b8570e27..429bbbaae8 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -56,6 +56,7 @@ JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, + JUMPSTART_MODEL_HUB_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model @@ -630,8 +631,13 @@ def _add_model_reference_arn_to_kwargs( def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" - - if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)): + # hub_arn is by default None unless the user specifies the hub_name + # If no hub_name is specified, it is assumed the public hub + is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False + if ( + _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)) + or is_private_hub + ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5997059cc3..e748be4c3b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,6 +1940,11 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" + # gated model never uses training model artifact + if self.gated_bucket: + return False + + # otherwise, return true is a training model package is not set return len(self.training_model_package_artifact_uris or {}) == 0 def is_gated_model(self) -> bool: diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index 633230f373..ece9c7807d 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -151,18 +151,18 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): }, ) - estimator = JumpStartEstimator.attach( - training_job_name=estimator.latest_training_job.name, - model_id=model_id, - model_version=model_version, - ) - - # uses ml.p3.2xlarge instance predictor = estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), ) - response = predictor.predict(["hello", "world"]) + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") assert response is not None diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 225b083cbe..4a64b413f4 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -688,7 +688,6 @@ def test_gated_model_non_model_package_s3_uri( instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", - model_uri='s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key', source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", entry_point="transfer_learning.py", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 8697a673c6..0b5ef63947 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -332,6 +332,9 @@ def test_jumpstart_model_header(): def test_use_training_model_artifact(): specs1 = JumpStartModelSpecs(BASE_SPEC) assert specs1.use_training_model_artifact() + specs1.gated_bucket = True + assert not specs1.use_training_model_artifact() + specs1.gated_bucket = False specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} assert not specs1.use_training_model_artifact() From ed93b9eacbcdca37546c692043431d4eed3985df Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 16:37:24 +0000 Subject: [PATCH 09/13] fix some formatting --- src/sagemaker/estimator.py | 1 - src/sagemaker/jumpstart/factory/estimator.py | 1 + src/sagemaker/jumpstart/types.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e16aa829cc..fa40719c9f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2511,7 +2511,6 @@ def start_new(cls, estimator, inputs, experiment_config): train_args = cls._get_train_args(estimator, inputs, experiment_config) logger.debug("Train args after processing defaults: %s", train_args) - estimator.sagemaker_session.train(**train_args) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 429bbbaae8..12eb30daaf 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -313,6 +313,7 @@ def _add_hub_access_config_to_kwargs_inputs( kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None ): """Adds HubAccessConfig to kwargs inputs""" + dataset_uri = kwargs.specs.default_training_dataset_uri if isinstance(kwargs.inputs, str): if dataset_uri is not None and dataset_uri == kwargs.inputs: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e748be4c3b..0cd4bcc902 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,7 +1940,7 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - # gated model never uses training model artifact + # gated model never use training model artifact if self.gated_bucket: return False From e4974f384bc9eb9bf237ab9b97bdcf2e92c74636 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 19:12:30 +0000 Subject: [PATCH 10/13] separate private hub setup code --- .../test_jumpstart_private_hub_estimator.py | 23 +---------- .../model/test_jumpstart_private_hub_model.py | 25 +---------- .../sagemaker/jumpstart/private_hub/setup.py | 41 +++++++++++++++++++ 3 files changed, 45 insertions(+), 44 deletions(-) create mode 100644 tests/integ/sagemaker/jumpstart/private_hub/setup.py diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index ece9c7807d..6e4e7b17f5 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -34,29 +34,10 @@ get_training_dataset_for_model_and_version, ) -MAX_INIT_TIME_SECONDS = 5 - -TEST_MODEL_IDS = { - "huggingface-spc-bert-base-cased", - "meta-textgeneration-llama-2-7b", - "catboost-regression-model", -} - +from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references -@with_exponential_backoff() -def create_model_reference(hub_instance, model_arn): - hub_instance.create_model_reference(model_arn=model_arn) - -@pytest.fixture(scope="session") -def add_model_references(): - # Create Model References to test in Hub - hub_instance = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() - ) - for model in TEST_MODEL_IDS: - model_arn = get_public_hub_model_arn(hub_instance, model) - create_model_reference(hub_instance, model_arn) +MAX_INIT_TIME_SECONDS = 5 def test_jumpstart_hub_estimator(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index a64db4a97d..5e17c98ec5 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -35,31 +35,10 @@ with_exponential_backoff, ) -MAX_INIT_TIME_SECONDS = 5 - -TEST_MODEL_IDS = { - "catboost-classification-model", - "huggingface-txt2img-conflictx-complex-lineart", - "meta-textgeneration-llama-2-7b", - "meta-textgeneration-llama-3-2-1b", - "catboost-regression-model", -} - +from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references -@with_exponential_backoff() -def create_model_reference(hub_instance, model_arn): - hub_instance.create_model_reference(model_arn=model_arn) - -@pytest.fixture(scope="session") -def add_model_references(): - # Create Model References to test in Hub - hub_instance = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() - ) - for model in TEST_MODEL_IDS: - model_arn = get_public_hub_model_arn(hub_instance, model) - create_model_reference(hub_instance, model_arn) +MAX_INIT_TIME_SECONDS = 5 def test_jumpstart_hub_model(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/private_hub/setup.py b/tests/integ/sagemaker/jumpstart/private_hub/setup.py new file mode 100644 index 0000000000..267272ac1b --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/setup.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import + +import os + +import pytest +from sagemaker.jumpstart.hub.hub import Hub + +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, +) + + +TEST_MODEL_IDS = { + "catboost-classification-model", + "huggingface-txt2img-conflictx-complex-lineart", + "meta-textgeneration-llama-2-7b", + "meta-textgeneration-llama-3-2-1b", + "catboost-regression-model", + "huggingface-spc-bert-base-cased", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + hub_instance.create_model_reference(model_arn=model_arn) + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) From cb57010ce2cbd2d5596f35d72370320b303dc5dc Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 19:27:19 +0000 Subject: [PATCH 11/13] add try catch block --- .../test_jumpstart_private_hub_estimator.py | 26 +++++++++++- .../model/test_jumpstart_private_hub_model.py | 28 ++++++++++++- .../sagemaker/jumpstart/private_hub/setup.py | 41 ------------------- 3 files changed, 50 insertions(+), 45 deletions(-) delete mode 100644 tests/integ/sagemaker/jumpstart/private_hub/setup.py diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index 6e4e7b17f5..e134d0ebd4 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -34,10 +34,32 @@ get_training_dataset_for_model_and_version, ) -from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references +MAX_INIT_TIME_SECONDS = 5 +TEST_MODEL_IDS = { + "huggingface-spc-bert-base-cased", + "meta-textgeneration-llama-2-7b", + "catboost-regression-model", +} -MAX_INIT_TIME_SECONDS = 5 + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) def test_jumpstart_hub_estimator(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index 5e17c98ec5..d62fa15105 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -35,10 +35,34 @@ with_exponential_backoff, ) -from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references +MAX_INIT_TIME_SECONDS = 5 +TEST_MODEL_IDS = { + "catboost-classification-model", + "huggingface-txt2img-conflictx-complex-lineart", + "meta-textgeneration-llama-2-7b", + "meta-textgeneration-llama-3-2-1b", + "catboost-regression-model", +} -MAX_INIT_TIME_SECONDS = 5 + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) def test_jumpstart_hub_model(setup, add_model_references): diff --git a/tests/integ/sagemaker/jumpstart/private_hub/setup.py b/tests/integ/sagemaker/jumpstart/private_hub/setup.py deleted file mode 100644 index 267272ac1b..0000000000 --- a/tests/integ/sagemaker/jumpstart/private_hub/setup.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import absolute_import - -import os - -import pytest -from sagemaker.jumpstart.hub.hub import Hub - -from tests.integ.sagemaker.jumpstart.constants import ( - ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, -) -from tests.integ.sagemaker.jumpstart.utils import ( - get_public_hub_model_arn, - get_sm_session, - with_exponential_backoff, -) - - -TEST_MODEL_IDS = { - "catboost-classification-model", - "huggingface-txt2img-conflictx-complex-lineart", - "meta-textgeneration-llama-2-7b", - "meta-textgeneration-llama-3-2-1b", - "catboost-regression-model", - "huggingface-spc-bert-base-cased", -} - - -@with_exponential_backoff() -def create_model_reference(hub_instance, model_arn): - hub_instance.create_model_reference(model_arn=model_arn) - - -@pytest.fixture(scope="session") -def add_model_references(): - # Create Model References to test in Hub - hub_instance = Hub( - hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() - ) - for model in TEST_MODEL_IDS: - model_arn = get_public_hub_model_arn(hub_instance, model) - create_model_reference(hub_instance, model_arn) From c6f0c178ad8c4be5021b3bd99832a995aaefaa3e Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 20:12:37 +0000 Subject: [PATCH 12/13] fix flake8 issue so except clause is not bare --- .../estimator/test_jumpstart_private_hub_estimator.py | 2 +- .../private_hub/model/test_jumpstart_private_hub_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index e134d0ebd4..05ce72e877 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -47,7 +47,7 @@ def create_model_reference(hub_instance, model_arn): try: hub_instance.create_model_reference(model_arn=model_arn) - except: + except(Exception): pass diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index d62fa15105..cee947400d 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -50,7 +50,7 @@ def create_model_reference(hub_instance, model_arn): try: hub_instance.create_model_reference(model_arn=model_arn) - except: + except(Exception): pass From d572fbe6ddd2d885044eb58bb681343414b7dc72 Mon Sep 17 00:00:00 2001 From: Rohan Narayan Date: Fri, 7 Mar 2025 20:25:23 +0000 Subject: [PATCH 13/13] black formatting --- .../estimator/test_jumpstart_private_hub_estimator.py | 2 +- .../private_hub/model/test_jumpstart_private_hub_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py index 05ce72e877..a6e33f1bdf 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -47,7 +47,7 @@ def create_model_reference(hub_instance, model_arn): try: hub_instance.create_model_reference(model_arn=model_arn) - except(Exception): + except Exception: pass diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py index cee947400d..c7e039693b 100644 --- a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -50,7 +50,7 @@ def create_model_reference(hub_instance, model_arn): try: hub_instance.create_model_reference(model_arn=model_arn) - except(Exception): + except Exception: pass