|
12 | 12 | # permissions and limitations under the License. |
13 | 13 | from __future__ import absolute_import |
14 | 14 |
|
| 15 | +from enum import Enum |
15 | 16 | from stepfunctions.inputs import ExecutionInput, StepInput |
16 | 17 | from stepfunctions.steps.states import Task |
17 | 18 | from stepfunctions.steps.fields import Field |
18 | | -from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder |
19 | | -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi |
| 19 | +from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn |
| 20 | +from stepfunctions.steps.integration_resources import IntegrationPattern |
20 | 21 |
|
21 | 22 | from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config |
22 | 23 | from sagemaker.model import Model, FrameworkModel |
23 | 24 | from sagemaker.model_monitor import DataCaptureConfig |
24 | 25 |
|
| 26 | +SageMaker = "sagemaker" |
| 27 | + |
| 28 | + |
| 29 | +class SageMakerApi(Enum): |
| 30 | + CreateTrainingJob = "createTrainingJob" |
| 31 | + CreateTransformJob = "createTransformJob" |
| 32 | + CreateModel = "createModel" |
| 33 | + CreateEndpointConfig = "createEndpointConfig" |
| 34 | + UpdateEndpoint = "updateEndpoint" |
| 35 | + CreateEndpoint = "createEndpoint" |
| 36 | + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" |
| 37 | + CreateProcessingJob = "createProcessingJob" |
| 38 | + |
| 39 | + |
25 | 40 | class TrainingStep(Task): |
26 | 41 |
|
27 | 42 | """ |
@@ -62,15 +77,17 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non |
62 | 77 | """ |
63 | 78 | Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync |
64 | 79 | """ |
65 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
66 | | - SageMakerApi.CreateTrainingJob, |
67 | | - IntegrationPattern.WaitForCompletion) |
| 80 | + |
| 81 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 82 | + SageMakerApi.CreateTrainingJob, |
| 83 | + IntegrationPattern.WaitForCompletion) |
68 | 84 | else: |
69 | 85 | """ |
70 | 86 | Example resource arn: arn:aws:states:::sagemaker:createTrainingJob |
71 | 87 | """ |
72 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
73 | | - SageMakerApi.CreateTrainingJob) |
| 88 | + |
| 89 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 90 | + SageMakerApi.CreateTrainingJob) |
74 | 91 |
|
75 | 92 | if isinstance(job_name, str): |
76 | 93 | parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) |
@@ -154,15 +171,17 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= |
154 | 171 | """ |
155 | 172 | Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync |
156 | 173 | """ |
157 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
158 | | - SageMakerApi.CreateTransformJob, |
159 | | - IntegrationPattern.WaitForCompletion) |
| 174 | + |
| 175 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 176 | + SageMakerApi.CreateTransformJob, |
| 177 | + IntegrationPattern.WaitForCompletion) |
160 | 178 | else: |
161 | 179 | """ |
162 | 180 | Example resource arn: arn:aws:states:::sagemaker:createTransformJob |
163 | 181 | """ |
164 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
165 | | - SageMakerApi.CreateTransformJob) |
| 182 | + |
| 183 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 184 | + SageMakerApi.CreateTransformJob) |
166 | 185 |
|
167 | 186 | if isinstance(job_name, str): |
168 | 187 | parameters = transform_config( |
@@ -248,8 +267,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No |
248 | 267 | """ |
249 | 268 | Example resource arn: arn:aws:states:::sagemaker:createModel |
250 | 269 | """ |
251 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
252 | | - SageMakerApi.CreateModel) |
| 270 | + |
| 271 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 272 | + SageMakerApi.CreateModel) |
253 | 273 |
|
254 | 274 | super(ModelStep, self).__init__(state_id, **kwargs) |
255 | 275 |
|
@@ -293,8 +313,9 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ |
293 | 313 | """ |
294 | 314 | Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig |
295 | 315 | """ |
296 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
297 | | - SageMakerApi.CreateEndpointConfig) |
| 316 | + |
| 317 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 318 | + SageMakerApi.CreateEndpointConfig) |
298 | 319 |
|
299 | 320 | kwargs[Field.Parameters.value] = parameters |
300 | 321 |
|
@@ -330,14 +351,16 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd |
330 | 351 | """ |
331 | 352 | Example resource arn: arn:aws:states:::sagemaker:updateEndpoint |
332 | 353 | """ |
333 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
334 | | - SageMakerApi.UpdateEndpoint) |
| 354 | + |
| 355 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 356 | + SageMakerApi.UpdateEndpoint) |
335 | 357 | else: |
336 | 358 | """ |
337 | 359 | Example resource arn: arn:aws:states:::sagemaker:createEndpoint |
338 | 360 | """ |
339 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
340 | | - SageMakerApi.CreateEndpoint) |
| 361 | + |
| 362 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 363 | + SageMakerApi.CreateEndpoint) |
341 | 364 |
|
342 | 365 | kwargs[Field.Parameters.value] = parameters |
343 | 366 |
|
@@ -378,15 +401,17 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta |
378 | 401 | """ |
379 | 402 | Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync |
380 | 403 | """ |
381 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
382 | | - SageMakerApi.CreateHyperParameterTuningJob, |
383 | | - IntegrationPattern.WaitForCompletion) |
| 404 | + |
| 405 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 406 | + SageMakerApi.CreateHyperParameterTuningJob, |
| 407 | + IntegrationPattern.WaitForCompletion) |
384 | 408 | else: |
385 | 409 | """ |
386 | 410 | Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob |
387 | 411 | """ |
388 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
389 | | - SageMakerApi.CreateHyperParameterTuningJob) |
| 412 | + |
| 413 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 414 | + SageMakerApi.CreateHyperParameterTuningJob) |
390 | 415 |
|
391 | 416 | parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() |
392 | 417 |
|
@@ -436,15 +461,17 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp |
436 | 461 | """ |
437 | 462 | Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync |
438 | 463 | """ |
439 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
440 | | - SageMakerApi.CreateProcessingJob, |
441 | | - IntegrationPattern.WaitForCompletion) |
| 464 | + |
| 465 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 466 | + SageMakerApi.CreateProcessingJob, |
| 467 | + IntegrationPattern.WaitForCompletion) |
442 | 468 | else: |
443 | 469 | """ |
444 | 470 | Example resource arn: arn:aws:states:::sagemaker:createProcessingJob |
445 | 471 | """ |
446 | | - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, |
447 | | - SageMakerApi.CreateProcessingJob) |
| 472 | + |
| 473 | + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, |
| 474 | + SageMakerApi.CreateProcessingJob) |
448 | 475 |
|
449 | 476 | if isinstance(job_name, str): |
450 | 477 | parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) |
|
0 commit comments