3434from stepfunctions .steps .sagemaker import TrainingStep , TransformStep , ModelStep , EndpointStep , EndpointConfigStep , TuningStep , ProcessingStep
3535from stepfunctions .workflow import Workflow
3636
37- from tests .integ import DATA_DIR , DEFAULT_TIMEOUT_MINUTES
37+ from tests .integ import DATA_DIR , DEFAULT_TIMEOUT_MINUTES , SAGEMAKER_RETRY_STRATEGY
3838from tests .integ .timeout import timeout
3939from tests .integ .utils import (
4040 state_machine_delete_wait ,
@@ -83,6 +83,7 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
8383 # Build workflow definition
8484 job_name = generate_job_name ()
8585 training_step = TrainingStep ('create_training_job_step' , estimator = pca_estimator_fixture , job_name = job_name , data = record_set_fixture , mini_batch_size = 200 )
86+ training_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
8687 workflow_graph = Chain ([training_step ])
8788
8889 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -110,6 +111,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
110111 # Build workflow definition
111112 model_name = generate_job_name ()
112113 model_step = ModelStep ('create_model_step' , model = trained_estimator .create_model (), model_name = model_name )
114+ model_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
113115 workflow_graph = Chain ([model_step ])
114116
115117 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -142,6 +144,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
142144
143145 # Create a model step to save the model
144146 model_step = ModelStep ('create_model_step' , model = trained_estimator .create_model (), model_name = job_name )
147+ model_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
145148
146149 # Upload data for transformation to S3
147150 data_path = os .path .join (DATA_DIR , "one_p_mnist" )
@@ -153,6 +156,7 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
153156
154157 # Build workflow definition
155158 transform_step = TransformStep ('create_transform_job_step' , pca_transformer , job_name = job_name , model_name = job_name , data = transform_input , content_type = "text/csv" )
159+ transform_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
156160 workflow_graph = Chain ([model_step , transform_step ])
157161
158162 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -184,6 +188,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
184188 # Build workflow definition
185189 endpoint_config_name = unique_name_from_base ("integ-test-endpoint-config" )
186190 endpoint_config_step = EndpointConfigStep ('create_endpoint_config_step' , endpoint_config_name = endpoint_config_name , model_name = model .name , initial_instance_count = INSTANCE_COUNT , instance_type = INSTANCE_TYPE )
191+ endpoint_config_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
187192 workflow_graph = Chain ([endpoint_config_step ])
188193
189194 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -224,6 +229,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
224229 # Build workflow definition
225230 endpoint_name = unique_name_from_base ("integ-test-endpoint" )
226231 endpoint_step = EndpointStep ('create_endpoint_step' , endpoint_name = endpoint_name , endpoint_config_name = model .name )
232+ endpoint_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
227233 workflow_graph = Chain ([endpoint_step ])
228234
229235 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -279,6 +285,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
279285
280286 # Build workflow definition
281287 tuning_step = TuningStep ('Tuning' , tuner = tuner , job_name = job_name , data = record_set_for_hyperparameter_tuning )
288+ tuning_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
282289 workflow_graph = Chain ([tuning_step ])
283290
284291 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -332,6 +339,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
332339 container_arguments = ['--train-test-split-ratio' , '0.2' ],
333340 container_entrypoint = ['python3' , '/opt/ml/processing/input/code/preprocessor.py' ],
334341 )
342+ processing_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
335343 workflow_graph = Chain ([processing_step ])
336344
337345 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
@@ -419,6 +427,7 @@ def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_
419427 container_entrypoint = execution_input ['entrypoint' ],
420428 parameters = parameters
421429 )
430+ processing_step .add_retry (SAGEMAKER_RETRY_STRATEGY )
422431 workflow_graph = Chain ([processing_step ])
423432
424433 with timeout (minutes = DEFAULT_TIMEOUT_MINUTES ):
0 commit comments