3434from sagemaker .core .remote_function .job import JOBS_CONTAINER_ENTRYPOINT
3535from sagemaker .core .s3 import s3_path_join
3636from sagemaker .core .helper .session_helper import Session
37- from sagemaker .core .common_utils import resolve_value_from_config , retry_with_backoff , format_tags , Tags
37+ from sagemaker .core .common_utils import (
38+ resolve_value_from_config ,
39+ retry_with_backoff ,
40+ format_tags ,
41+ Tags ,
42+ )
43+
3844# Orchestration imports (now in mlops)
3945from sagemaker .mlops .workflow .callback_step import CallbackOutput , CallbackStep
4046from sagemaker .mlops .workflow ._event_bridge_client_helper import (
4450 EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT ,
4551)
4652from sagemaker .mlops .workflow .lambda_step import LambdaOutput , LambdaStep
53+ from sagemaker .core .shapes .shapes import MlflowConfig
4754from sagemaker .core .helper .pipeline_variable import (
4855 RequestType ,
4956 PipelineVariable ,
5057)
58+
5159# Primitive imports (stay in core)
5260from sagemaker .core .workflow .execution_variables import ExecutionVariables
5361from sagemaker .core .workflow .parameters import Parameter
62+
5463# Orchestration imports (now in mlops)
5564from sagemaker .core .workflow .pipeline_definition_config import PipelineDefinitionConfig
5665from sagemaker .mlops .workflow .pipeline_experiment_config import PipelineExperimentConfig
5766from sagemaker .mlops .workflow .parallelism_config import ParallelismConfiguration
67+
5868# Primitive imports (stay in core)
5969from sagemaker .core .workflow .properties import Properties
70+
6071# Orchestration imports (now in mlops)
6172from sagemaker .mlops .workflow .selective_execution_config import SelectiveExecutionConfig
6273from sagemaker .core .workflow .step_outputs import StepOutput
@@ -89,6 +100,7 @@ def __init__(
89100 name : str = "" ,
90101 parameters : Optional [Sequence [Parameter ]] = None ,
91102 pipeline_experiment_config : Optional [PipelineExperimentConfig ] = _DEFAULT_EXPERIMENT_CFG ,
103+ mlflow_config : Optional [MlflowConfig ] = None ,
92104 steps : Optional [Sequence [Union [Step , StepOutput ]]] = None ,
93105 sagemaker_session : Optional [Session ] = None ,
94106 pipeline_definition_config : Optional [PipelineDefinitionConfig ] = _DEFAULT_DEFINITION_CFG ,
@@ -104,6 +116,8 @@ def __init__(
104116 the same name already exists. By default, pipeline name is used as
105117 experiment name and execution id is used as the trial name.
106118 If set to None, no experiment or trial will be created automatically.
119+ mlflow_config (Optional[MlflowConfig]): If set, the pipeline will be configured
120+ with MLflow tracking for experiment tracking and model versioning.
107121 steps (Sequence[Union[Step, StepOutput]]): The list of the
108122 non-conditional steps associated with the pipeline. Any steps that are within the
109123 `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -120,6 +134,7 @@ def __init__(
120134 self .name = name
121135 self .parameters = parameters if parameters else []
122136 self .pipeline_experiment_config = pipeline_experiment_config
137+ self .mlflow_config = mlflow_config
123138 self .steps = steps if steps else []
124139 self .sagemaker_session = sagemaker_session if sagemaker_session else Session ()
125140 self .pipeline_definition_config = pipeline_definition_config
@@ -359,6 +374,7 @@ def start(
359374 execution_description : str = None ,
360375 parallelism_config : ParallelismConfiguration = None ,
361376 selective_execution_config : SelectiveExecutionConfig = None ,
377+ mlflow_experiment_name : str = None ,
362378 pipeline_version_id : int = None ,
363379 ):
364380 """Starts a Pipeline execution in the Workflow service.
@@ -373,6 +389,10 @@ def start(
373389 over the parallelism configuration of the parent pipeline.
374390 selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
375391 selective step execution.
392+ mlflow_experiment_name (str): Optional MLflow experiment name to override
393+ the experiment name specified in the pipeline's mlflow_config.
394+ If provided, this will override the experiment name for this specific
395+ pipeline execution only, without modifying the pipeline definition.
376396 pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not
377397 specified, uses the latest version ID.
378398
@@ -396,6 +416,7 @@ def start(
396416 PipelineExecutionDisplayName = execution_display_name ,
397417 ParallelismConfiguration = parallelism_config ,
398418 SelectiveExecutionConfig = selective_execution_config ,
419+ MlflowExperimentName = mlflow_experiment_name ,
399420 PipelineVersionId = pipeline_version_id ,
400421 )
401422 if self .sagemaker_session .local_mode :
@@ -435,14 +456,23 @@ def definition(self) -> str:
435456 if self .pipeline_experiment_config is not None
436457 else None
437458 ),
459+ "MlflowConfig" : _convert_mlflow_config_to_request (self .mlflow_config ),
438460 "Steps" : list_to_request (compiled_steps ),
439461 }
440-
441- request_dict ["PipelineExperimentConfig" ] = interpolate (
442- request_dict ["PipelineExperimentConfig" ], {}, {}, pipeline_name = self .name
443- )
444462 callback_output_to_step_map = _map_callback_outputs (self .steps )
445463 lambda_output_to_step_name = _map_lambda_outputs (self .steps )
464+ request_dict ["PipelineExperimentConfig" ] = interpolate (
465+ request_dict ["PipelineExperimentConfig" ],
466+ callback_output_to_step_map = callback_output_to_step_map ,
467+ lambda_output_to_step_map = lambda_output_to_step_name ,
468+ pipeline_name = self .name ,
469+ )
470+ request_dict ["MlflowConfig" ] = interpolate (
471+ request_dict ["MlflowConfig" ],
472+ callback_output_to_step_map = callback_output_to_step_map ,
473+ lambda_output_to_step_map = lambda_output_to_step_name ,
474+ pipeline_name = self .name ,
475+ )
446476 request_dict ["Steps" ] = interpolate (
447477 request_dict ["Steps" ],
448478 callback_output_to_step_map = callback_output_to_step_map ,
@@ -730,6 +760,34 @@ def delete_triggers(self, trigger_names: List[str]):
730760 logger .info ("Deleted Pipeline Schedule: %s ..." , trigger_name )
731761
732762
763+ def _convert_mlflow_config_to_request (mlflow_config : MlflowConfig ) -> dict :
764+ """Convert sagemaker-core MlflowConfig to pipeline request format.
765+
766+ Args:
767+ mlflow_config: MlflowConfig instance from sagemaker.core.shapes.shapes
768+
769+ Returns:
770+ dict: Request format for pipeline MLflow configuration
771+ """
772+ if mlflow_config is None :
773+ return None
774+
775+ from sagemaker .core .utils .utils import Unassigned
776+
777+ resource_arn = mlflow_config .mlflow_resource_arn
778+ if isinstance (resource_arn , Unassigned ):
779+ resource_arn = None
780+
781+ experiment_name = mlflow_config .mlflow_experiment_name
782+ if isinstance (experiment_name , Unassigned ):
783+ experiment_name = None
784+
785+ return {
786+ "MlflowResourceArn" : resource_arn ,
787+ "MlflowExperimentName" : experiment_name ,
788+ }
789+
790+
733791def format_start_parameters (parameters : Dict [str , Any ]) -> List [Dict [str , Any ]]:
734792 """Formats start parameter overrides as a list of dicts.
735793
@@ -1135,7 +1193,6 @@ def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
11351193 if isinstance (child_step , Step ):
11361194 dependency_list [child_step .name ].add (step .name )
11371195
1138-
11391196 adjacency_list = {}
11401197 for step in dependency_list :
11411198 for step_dependency in dependency_list [step ]:
@@ -1173,9 +1230,7 @@ def is_cyclic_helper(current_step):
11731230 return True
11741231 return False
11751232
1176- def get_steps_in_sub_dag (
1177- self , current_step : Step , sub_dag_steps : Set [str ] = None
1178- ) -> Set [str ]:
1233+ def get_steps_in_sub_dag (self , current_step : Step , sub_dag_steps : Set [str ] = None ) -> Set [str ]:
11791234 """Get names of all steps (including current step) in the sub dag of current step.
11801235
11811236 Returns a set of step names in the sub dag.
@@ -1215,4 +1270,4 @@ def __next__(self) -> Step:
12151270
12161271 while self .stack :
12171272 return self .step_map .get (self .stack .pop ())
1218- raise StopIteration
1273+ raise StopIteration
0 commit comments