Skip to content

Commit b2bb22e

Browse files
jkasiraj, Jaya Kasiraj, and pintaoz-aws
authored
feat: add support for mlflow to pipelines (#5429)
* feat: add support for mlflow to pipelines
* refactor: use MlflowConfig from sagemaker-core

---------

Co-authored-by: Jaya Kasiraj <jkasiraj@amazon.com>
Co-authored-by: pintaoz-aws <167920275+pintaoz-aws@users.noreply.github.com>
1 parent efa8d7d commit b2bb22e

File tree

5 files changed

+525
-11
lines changed

5 files changed

+525
-11
lines changed

sagemaker-mlops/src/sagemaker/mlops/local/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.core.workflow.functions import Join, JsonGet, PropertyFile
2727
from sagemaker.core.workflow.properties import Properties
2828
from sagemaker.core.workflow.execution_variables import ExecutionVariable, ExecutionVariables
29+
2930
# Orchestration imports (now in mlops)
3031
from sagemaker.mlops.workflow.function_step import DelayedReturn
3132
from sagemaker.mlops.workflow.steps import StepTypeEnum, Step
@@ -560,4 +561,4 @@ def get(self, step: Step) -> _StepExecutor:
560561
self.pipeline_executor.execution.update_step_failure(
561562
step.name, f"Unsupported step type {step_type} to execute."
562563
)
563-
return step_executor
564+
return step_executor

sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT
3535
from sagemaker.core.s3 import s3_path_join
3636
from sagemaker.core.helper.session_helper import Session
37-
from sagemaker.core.common_utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags
37+
from sagemaker.core.common_utils import (
38+
resolve_value_from_config,
39+
retry_with_backoff,
40+
format_tags,
41+
Tags,
42+
)
43+
3844
# Orchestration imports (now in mlops)
3945
from sagemaker.mlops.workflow.callback_step import CallbackOutput, CallbackStep
4046
from sagemaker.mlops.workflow._event_bridge_client_helper import (
@@ -44,19 +50,24 @@
4450
EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
4551
)
4652
from sagemaker.mlops.workflow.lambda_step import LambdaOutput, LambdaStep
53+
from sagemaker.core.shapes.shapes import MlflowConfig
4754
from sagemaker.core.helper.pipeline_variable import (
4855
RequestType,
4956
PipelineVariable,
5057
)
58+
5159
# Primitive imports (stay in core)
5260
from sagemaker.core.workflow.execution_variables import ExecutionVariables
5361
from sagemaker.core.workflow.parameters import Parameter
62+
5463
# Orchestration imports (now in mlops)
5564
from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig
5665
from sagemaker.mlops.workflow.pipeline_experiment_config import PipelineExperimentConfig
5766
from sagemaker.mlops.workflow.parallelism_config import ParallelismConfiguration
67+
5868
# Primitive imports (stay in core)
5969
from sagemaker.core.workflow.properties import Properties
70+
6071
# Orchestration imports (now in mlops)
6172
from sagemaker.mlops.workflow.selective_execution_config import SelectiveExecutionConfig
6273
from sagemaker.core.workflow.step_outputs import StepOutput
@@ -89,6 +100,7 @@ def __init__(
89100
name: str = "",
90101
parameters: Optional[Sequence[Parameter]] = None,
91102
pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG,
103+
mlflow_config: Optional[MlflowConfig] = None,
92104
steps: Optional[Sequence[Union[Step, StepOutput]]] = None,
93105
sagemaker_session: Optional[Session] = None,
94106
pipeline_definition_config: Optional[PipelineDefinitionConfig] = _DEFAULT_DEFINITION_CFG,
@@ -104,6 +116,8 @@ def __init__(
104116
the same name already exists. By default, pipeline name is used as
105117
experiment name and execution id is used as the trial name.
106118
If set to None, no experiment or trial will be created automatically.
119+
mlflow_config (Optional[MlflowConfig]): If set, the pipeline will be configured
120+
with MLflow tracking for experiment tracking and model versioning.
107121
steps (Sequence[Union[Step, StepOutput]]): The list of the
108122
non-conditional steps associated with the pipeline. Any steps that are within the
109123
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -120,6 +134,7 @@ def __init__(
120134
self.name = name
121135
self.parameters = parameters if parameters else []
122136
self.pipeline_experiment_config = pipeline_experiment_config
137+
self.mlflow_config = mlflow_config
123138
self.steps = steps if steps else []
124139
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
125140
self.pipeline_definition_config = pipeline_definition_config
@@ -359,6 +374,7 @@ def start(
359374
execution_description: str = None,
360375
parallelism_config: ParallelismConfiguration = None,
361376
selective_execution_config: SelectiveExecutionConfig = None,
377+
mlflow_experiment_name: str = None,
362378
pipeline_version_id: int = None,
363379
):
364380
"""Starts a Pipeline execution in the Workflow service.
@@ -373,6 +389,10 @@ def start(
373389
over the parallelism configuration of the parent pipeline.
374390
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
375391
selective step execution.
392+
mlflow_experiment_name (str): Optional MLflow experiment name to override
393+
the experiment name specified in the pipeline's mlflow_config.
394+
If provided, this will override the experiment name for this specific
395+
pipeline execution only, without modifying the pipeline definition.
376396
pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not
377397
specified, uses the latest version ID.
378398
@@ -396,6 +416,7 @@ def start(
396416
PipelineExecutionDisplayName=execution_display_name,
397417
ParallelismConfiguration=parallelism_config,
398418
SelectiveExecutionConfig=selective_execution_config,
419+
MlflowExperimentName=mlflow_experiment_name,
399420
PipelineVersionId=pipeline_version_id,
400421
)
401422
if self.sagemaker_session.local_mode:
@@ -435,14 +456,23 @@ def definition(self) -> str:
435456
if self.pipeline_experiment_config is not None
436457
else None
437458
),
459+
"MlflowConfig": _convert_mlflow_config_to_request(self.mlflow_config),
438460
"Steps": list_to_request(compiled_steps),
439461
}
440-
441-
request_dict["PipelineExperimentConfig"] = interpolate(
442-
request_dict["PipelineExperimentConfig"], {}, {}, pipeline_name=self.name
443-
)
444462
callback_output_to_step_map = _map_callback_outputs(self.steps)
445463
lambda_output_to_step_name = _map_lambda_outputs(self.steps)
464+
request_dict["PipelineExperimentConfig"] = interpolate(
465+
request_dict["PipelineExperimentConfig"],
466+
callback_output_to_step_map=callback_output_to_step_map,
467+
lambda_output_to_step_map=lambda_output_to_step_name,
468+
pipeline_name=self.name,
469+
)
470+
request_dict["MlflowConfig"] = interpolate(
471+
request_dict["MlflowConfig"],
472+
callback_output_to_step_map=callback_output_to_step_map,
473+
lambda_output_to_step_map=lambda_output_to_step_name,
474+
pipeline_name=self.name,
475+
)
446476
request_dict["Steps"] = interpolate(
447477
request_dict["Steps"],
448478
callback_output_to_step_map=callback_output_to_step_map,
@@ -730,6 +760,34 @@ def delete_triggers(self, trigger_names: List[str]):
730760
logger.info("Deleted Pipeline Schedule: %s ...", trigger_name)
731761

732762

763+
def _convert_mlflow_config_to_request(mlflow_config: MlflowConfig) -> dict:
764+
"""Convert sagemaker-core MlflowConfig to pipeline request format.
765+
766+
Args:
767+
mlflow_config: MlflowConfig instance from sagemaker.core.shapes.shapes
768+
769+
Returns:
770+
dict: Request format for pipeline MLflow configuration
771+
"""
772+
if mlflow_config is None:
773+
return None
774+
775+
from sagemaker.core.utils.utils import Unassigned
776+
777+
resource_arn = mlflow_config.mlflow_resource_arn
778+
if isinstance(resource_arn, Unassigned):
779+
resource_arn = None
780+
781+
experiment_name = mlflow_config.mlflow_experiment_name
782+
if isinstance(experiment_name, Unassigned):
783+
experiment_name = None
784+
785+
return {
786+
"MlflowResourceArn": resource_arn,
787+
"MlflowExperimentName": experiment_name,
788+
}
789+
790+
733791
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
734792
"""Formats start parameter overrides as a list of dicts.
735793
@@ -1135,7 +1193,6 @@ def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
11351193
if isinstance(child_step, Step):
11361194
dependency_list[child_step.name].add(step.name)
11371195

1138-
11391196
adjacency_list = {}
11401197
for step in dependency_list:
11411198
for step_dependency in dependency_list[step]:
@@ -1173,9 +1230,7 @@ def is_cyclic_helper(current_step):
11731230
return True
11741231
return False
11751232

1176-
def get_steps_in_sub_dag(
1177-
self, current_step: Step, sub_dag_steps: Set[str] = None
1178-
) -> Set[str]:
1233+
def get_steps_in_sub_dag(self, current_step: Step, sub_dag_steps: Set[str] = None) -> Set[str]:
11791234
"""Get names of all steps (including current step) in the sub dag of current step.
11801235
11811236
Returns a set of step names in the sub dag.
@@ -1215,4 +1270,4 @@ def __next__(self) -> Step:
12151270

12161271
while self.stack:
12171272
return self.step_map.get(self.stack.pop())
1218-
raise StopIteration
1273+
raise StopIteration

0 commit comments

Comments
 (0)