@@ -68,6 +68,41 @@ def pca_estimator():
6868 return pca
6969
7070
71+ @pytest .fixture
72+ def pca_estimator_with_env ():
73+ s3_output_location = 's3://sagemaker/models'
74+
75+ pca = sagemaker .estimator .Estimator (
76+ PCA_IMAGE ,
77+ role = EXECUTION_ROLE ,
78+ instance_count = 1 ,
79+ instance_type = 'ml.c4.xlarge' ,
80+ output_path = s3_output_location ,
81+ environment = {
82+ 'JobName' : "job_name" ,
83+ 'ModelName' : "model_name"
84+ },
85+ subnets = [
86+ 'subnet-00000000000000000' ,
87+ 'subnet-00000000000000001'
88+ ]
89+ )
90+
91+ pca .set_hyperparameters (
92+ feature_dim = 50000 ,
93+ num_components = 10 ,
94+ subtract_mean = True ,
95+ algorithm_mode = 'randomized' ,
96+ mini_batch_size = 200
97+ )
98+
99+ pca .sagemaker_session = MagicMock ()
100+ pca .sagemaker_session .boto_region_name = 'us-east-1'
101+ pca .sagemaker_session ._default_bucket = 'sagemaker'
102+
103+ return pca
104+
105+
71106@pytest .fixture
72107def pca_estimator_with_debug_hook ():
73108 s3_output_location = 's3://sagemaker/models'
@@ -156,6 +191,31 @@ def pca_model():
156191 )
157192
158193
194+ @pytest .fixture
195+ def pca_model_with_env ():
196+ model_data = 's3://sagemaker/models/pca.tar.gz'
197+ return Model (
198+ model_data = model_data ,
199+ image_uri = PCA_IMAGE ,
200+ role = EXECUTION_ROLE ,
201+ name = 'pca-model' ,
202+ env = {
203+ 'JobName' : "job_name" ,
204+ 'ModelName' : "model_name"
205+ },
206+ vpc_config = {
207+ "SecurityGroupIds" : ["sg-00000000000000000" ],
208+ "Subnets" : ["subnet-00000000000000000" , "subnet-00000000000000001" ]
209+ },
210+ image_config = {
211+ "RepositoryAccessMode" : "Vpc" ,
212+ "RepositoryAuthConfig" : {
213+ "RepositoryCredentialsProviderArn" : "arn"
214+ }
215+ }
216+ )
217+
218+
159219@pytest .fixture
160220def pca_transformer (pca_model ):
161221 return Transformer (
@@ -537,6 +597,63 @@ def test_training_step_creation_with_model(pca_estimator):
537597 }
538598
539599
600+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
601+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
602+ def test_training_step_creation_with_model_with_env (pca_estimator_with_env ):
603+ training_step = TrainingStep ('Training' , estimator = pca_estimator_with_env , job_name = 'TrainingJob' )
604+ model_step = ModelStep ('Training - Save Model' , training_step .get_expected_model (model_name = training_step .output ()['TrainingJobName' ]))
605+ training_step .next (model_step )
606+ assert training_step .to_dict () == {
607+ 'Type' : 'Task' ,
608+ 'Parameters' : {
609+ 'AlgorithmSpecification' : {
610+ 'TrainingImage' : PCA_IMAGE ,
611+ 'TrainingInputMode' : 'File'
612+ },
613+ 'OutputDataConfig' : {
614+ 'S3OutputPath' : 's3://sagemaker/models'
615+ },
616+ 'StoppingCondition' : {
617+ 'MaxRuntimeInSeconds' : 86400
618+ },
619+ 'ResourceConfig' : {
620+ 'InstanceCount' : 1 ,
621+ 'InstanceType' : 'ml.c4.xlarge' ,
622+ 'VolumeSizeInGB' : 30
623+ },
624+ 'RoleArn' : EXECUTION_ROLE ,
625+ 'HyperParameters' : {
626+ 'feature_dim' : '50000' ,
627+ 'num_components' : '10' ,
628+ 'subtract_mean' : 'True' ,
629+ 'algorithm_mode' : 'randomized' ,
630+ 'mini_batch_size' : '200'
631+ },
632+ 'TrainingJobName' : 'TrainingJob'
633+ },
634+ 'Resource' : 'arn:aws:states:::sagemaker:createTrainingJob.sync' ,
635+ 'Next' : 'Training - Save Model'
636+ }
637+
638+ assert model_step .to_dict () == {
639+ 'Type' : 'Task' ,
640+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
641+ 'Parameters' : {
642+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
643+ 'ModelName.$' : "$['TrainingJobName']" ,
644+ 'PrimaryContainer' : {
645+ 'Environment' : {
646+ 'JobName' : 'job_name' ,
647+ 'ModelName' : 'model_name'
648+ },
649+ 'Image' : PCA_IMAGE ,
650+ 'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']"
651+ }
652+ },
653+ 'End' : True
654+ }
655+
656+
540657@patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
541658@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
542659def test_training_step_creation_with_framework (tensorflow_estimator ):
@@ -806,6 +923,31 @@ def test_get_expected_model(pca_estimator):
806923 }
807924
808925
926+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
927+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
928+ def test_get_expected_model_with_env (pca_estimator_with_env ):
929+ training_step = TrainingStep ('Training' , estimator = pca_estimator_with_env , job_name = 'TrainingJob' )
930+ expected_model = training_step .get_expected_model ()
931+ model_step = ModelStep ('Create model' , model = expected_model , model_name = 'pca-model' )
932+ assert model_step .to_dict () == {
933+ 'Type' : 'Task' ,
934+ 'Parameters' : {
935+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
936+ 'ModelName' : 'pca-model' ,
937+ 'PrimaryContainer' : {
938+ 'Environment' : {
939+ 'JobName' : 'job_name' ,
940+ 'ModelName' : 'model_name'
941+ },
942+ 'Image' : expected_model .image_uri ,
943+ 'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']"
944+ }
945+ },
946+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
947+ 'End' : True
948+ }
949+
950+
809951@patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
810952@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
811953def test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
@@ -859,6 +1001,29 @@ def test_model_step_creation(pca_model):
8591001 }
8601002
8611003
1004+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
1005+ def test_model_step_creation_with_env (pca_model_with_env ):
1006+ step = ModelStep ('Create model' , model = pca_model_with_env , model_name = 'pca-model' , tags = DEFAULT_TAGS )
1007+ assert step .to_dict () == {
1008+ 'Type' : 'Task' ,
1009+ 'Parameters' : {
1010+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
1011+ 'ModelName' : 'pca-model' ,
1012+ 'PrimaryContainer' : {
1013+ 'Environment' : {
1014+ 'JobName' : 'job_name' ,
1015+ 'ModelName' : 'model_name'
1016+ },
1017+ 'Image' : pca_model_with_env .image_uri ,
1018+ 'ModelDataUrl' : pca_model_with_env .model_data
1019+ },
1020+ 'Tags' : DEFAULT_TAGS_LIST
1021+ },
1022+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
1023+ 'End' : True
1024+ }
1025+
1026+
8621027@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
8631028def test_endpoint_config_step_creation (pca_model ):
8641029 data_capture_config = DataCaptureConfig (
0 commit comments