From 5484ba0e5f50564f8903153ace060bb1221eb4aa Mon Sep 17 00:00:00 2001 From: Sophia Date: Fri, 14 Nov 2025 22:00:48 -0800 Subject: [PATCH 01/31] feat: Implement elastic training cli arguments (#273) * feat: Implement elastic training cli arguments * Add elastic training unified config and unit test * Add graceful shutdown and scaling timeout to cli args --- .../v1_1/model.py | 48 +++++- .../v1_1/schema.json | 59 +++++++- .../hyperpod_pytorch_job_unified_config.py | 142 +++++++++++++++++- test/unit_tests/cli/test_training.py | 56 +++++++ 4 files changed, 300 insertions(+), 5 deletions(-) diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py index 01cf8075..d3a60de0 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py @@ -11,7 +11,8 @@ Metadata, Volumes, HostPath, - PersistentVolumeClaim + PersistentVolumeClaim, + ElasticPolicy ) from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob import yaml @@ -222,6 +223,28 @@ class PyTorchJobConfig(BaseModel): alias="required_topology", description="Required topology annotation for scheduling", ) + elastic_replica_increment_step: Optional[int] = Field( + default=None, + alias="elastic_replica_increment_step", + description="Scaling step size for elastic training", + ge=1, + ) + max_node_count: Optional[int] = Field( + default=None, + alias="max_node_count", + description="Maximum number of nodes for elastic training", + ge=1, + ) + elastic_graceful_shutdown_timeout_seconds: Optional[int] = Field( + default=None, + alias="elastic_graceful_shutdown_timeout_seconds", + description="Graceful shutdown timeout in seconds for elastic scaling operations" + ) + elastic_scaling_timeout: Optional[str] = Field( + default=None, + alias="elastic_scaling_timeout", + description="Scaling timeout for 
elastic training" + ) @field_validator('tasks_per_node', mode='before') @classmethod @@ -431,15 +454,34 @@ def build_dict(**kwargs): replica_kwargs = build_dict( name="pod", template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)), - replicas=self.node_count + replicas=self.node_count, + max_replicas=self.max_node_count ) + # Build elastic policy + elastic_policy = None + if any([ + self.elastic_replica_increment_step is not None, + self.max_node_count is not None, + self.elastic_graceful_shutdown_timeout_seconds is not None, + self.elastic_scaling_timeout is not None + ]): + elastic_policy_kwargs = build_dict( + min_replicas=self.node_count, + replica_increment_step=self.elastic_replica_increment_step, + max_replicas=self.max_node_count, + graceful_shutdown_timeout_seconds=self.elastic_graceful_shutdown_timeout_seconds, + scaling_timeout=self.elastic_scaling_timeout + ) + elastic_policy = ElasticPolicy(**elastic_policy_kwargs) + # Build job job_kwargs = build_dict( metadata=metadata_kwargs, replica_specs=[ReplicaSpec(**replica_kwargs)], nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None, - run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None + run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None, + elastic_policy=elastic_policy ) result = HyperPodPytorchJob(**job_kwargs) diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json index 41abed18..5d400619 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json @@ -372,7 +372,64 @@ "type": "string", "description": "Required topology annotation for scheduling", "$ref": "#/$defs/topologyLabels" + }, + "elastic_replica_increment_step": { + 
"anyOf": [ + { + "minimum": 1, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Scaling step size for elastic training", + "title": "Elastic Training Replica Increment Step" + }, + "max_node_count": { + "anyOf": [ + { + "minimum": 1, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum number of nodes for elastic training", + "title": "Max Node Count" + }, + "elastic_graceful_shutdown_timeout_seconds": { + "anyOf": [ + { + "minimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Graceful shutdown timeout in seconds for elastic scaling operations", + "title": "Elastic Graceful Shutdown Timeout Seconds" + }, + "elastic_scaling_timeout": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" } + ], + "default": null, + "description": "Scaling timeout for elastic training", + "title": "Elastic Scaling Timeout" + } + }, "required": [ "job_name", @@ -380,4 +437,4 @@ ], "title": "PyTorchJobConfig", "type": "object" -} \ No newline at end of file +} diff --git a/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py b/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py index a7855ef5..5e2f30ba 100644 --- a/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py +++ b/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py @@ -2982,6 +2982,11 @@ class ReplicaSpec(BaseModel): default=0, description="Replicas is the desired number of replicas of the given template.", ) + maxReplicas: Optional[int] = Field( + default=None, + alias="max_replicas", + description="Maximum replicas for elastic training" + ) spares: Optional[int] = Field( default=0, description="Spares requests spare resources from Kueue. E.g. 
If a job is configured with 4 replicas and 2 spares, job requests resources required to run 6 pods such as cpu, gpu", @@ -2991,6 +2996,52 @@ class ReplicaSpec(BaseModel): description="Template is the object that describes the pod that will be created for this replica.", ) +class ElasticPolicy(BaseModel): + """ElasticPolicy defines the elastic training policy""" + + model_config = ConfigDict(extra="forbid") + + replicaIncrementStep: Optional[int] = Field( + default=None, + alias="replica_increment_step", + description="Step size for elastic replica scaling" + ) + minReplicas: Optional[int] = Field( + default=None, + alias="min_replicas", + description="Minimum number of replicas" + ) + maxReplicas: Optional[int] = Field( + default=None, + alias="max_replicas", + description="Maximum number of replicas" + ) + replicaDiscreteValues: Optional[List[int]] = Field( + default=None, + alias="replica_discrete_values", + description="Alternative to ReplicaIncrementStep. Provides exact values for total replicas count" + ) + scalingTimeout: Optional[str] = Field( + default=None, + alias="scaling_timeout", + description="Timeout for scaling operations" + ) + scalingPolicy: Optional[str] = Field( + default=None, + alias="scaling_policy", + description="Scaling policy behavior (auto, aggressive, conservative)" + ) + gracefulShutdownTimeoutSeconds: Optional[int] = Field( + default=None, + alias="graceful_shutdown_timeout_seconds", + description="Graceful shutdown timeout in seconds for elastic scaling operations" + ) + faultyScaleDownTimeoutSeconds: Optional[int] = Field( + default=None, + alias="faulty_scale_down_timeout_seconds", + description="Timeout in seconds after entering Faulted state before triggering faulty pod scale-down" + ) + class LogMonitoringConfiguration(BaseModel): """LogMonitoringRule defines the criteria used to detect a SLOW or HANGING job""" @@ -3098,6 +3149,11 @@ class RunPolicy(BaseModel): alias="ttl_seconds_after_finished", 
description="TTLSecondsAfterFinished is the TTL to clean up jobs. Set to -1 for infinite", ) + workloadMode: Optional[str] = Field( + default=None, + alias="workload_mode", + description="Workload deployment mode for elastic training (e.g., 'Deployment')", + ) class PodSets(BaseModel): @@ -3153,6 +3209,43 @@ class Pods(BaseModel): ) +class ElasticScalingStatus(BaseModel): + """ElasticScalingStatus represents the current state of elastic scaling operations""" + + model_config = ConfigDict(extra="forbid") + + targetReplicas: Optional[Dict[str, int]] = Field( + default=None, + alias="target_replicas", + description="TargetReplicas contains the desired replica counts per ReplicaSpec name", + ) + lastUpdated: Optional[str] = Field( + default=None, + alias="last_updated", + description="LastUpdated is the timestamp when this status was last modified", + ) + lastScalingTime: Optional[str] = Field( + default=None, + alias="last_scaling_time", + description="LastScalingTime tracks when the last scaling operation completed", + ) + lastRestartTime: Optional[str] = Field( + default=None, + alias="last_restart_time", + description="LastRestartTime tracks when the job was last restarted for scaleUpRestartTimeout", + ) + podsScaled: Optional[bool] = Field( + default=None, + alias="pods_scaled", + description="PodsScaled indicates whether pods have already been scaled in this scaling round", + ) + isFaultyPodScaleDown: Optional[bool] = Field( + default=None, + alias="is_faulty_pod_scale_down", + description="IsFaultyPodScaleDown indicates this scaling operation is removing faulty pods", + ) + + class RestartStatus(BaseModel): """Additional restart limiting status""" @@ -3171,6 +3264,33 @@ class RestartStatus(BaseModel): ) +class FaultyPodInstanceList(BaseModel): + """FaultyPodInstanceRecord tracks faulty pod/instances for each restart""" + + model_config = ConfigDict(extra="forbid") + + restartType: Optional[str] = Field( + default=None, + alias="restart_type", + 
description="RestartType indicates whether this was a PLR or JLR" + ) + faultyInstanceIdList: Optional[List[str]] = Field( + default_factory=list, + alias="faulty_instance_id_list", + description="FaultyInstanceIdList tracks faulty instance ids" + ) + faultyPodList: Optional[List[str]] = Field( + default_factory=list, + alias="faulty_pod_list", + description="FaultyPodList tracks faulty pod names" + ) + faultyRankList: Optional[List[str]] = Field( + default_factory=list, + alias="faulty_rank_list", + description="FaultyRankList tracks faulty pod ranks" + ) + + class HyperPodPytorchJobStatus(BaseModel): """HyperPodPytorchJobStatus defines the observed state of HyperPodPytorchJob""" @@ -3187,6 +3307,11 @@ class HyperPodPytorchJobStatus(BaseModel): alias="job_pods", description="The StatefulSet containing the training pods", ) + latestFaultyPodInstanceList: Optional[FaultyPodInstanceList] = Field( + default=None, + alias="latest_faulty_pod_instance_list", + description="LatestFaultyPodInstanceList tracks faulty pods/nodes of latest restart" + ) managerPods: Optional[ManagerPods] = Field( default=None, alias="manager_pods", description="Pod Manager pods" ) @@ -3221,6 +3346,16 @@ class HyperPodPytorchJobStatus(BaseModel): alias="restart_status", description="Additional restart limiting status", ) + elasticScalingStatus: Optional[ElasticScalingStatus] = Field( + default=None, + alias="elastic_scaling_status", + description="ElasticScalingStatus contains the current state of elastic scaling operations", + ) + elasticWorkloadRef: Optional[Dict[str, str]] = Field( + default=None, + alias="elastic_workload_ref", + description="Reference to associated ElasticWorkload (optional, only set when ElasticPolicy is present)", + ) startTime: Optional[str] = Field( default=None, alias="start_time", @@ -3245,4 +3380,9 @@ class _HyperPodPytorchJob(BaseModel): ) runPolicy: Optional[RunPolicy] = Field( default=None, alias="run_policy", description="RunPolicy" - ) \ No newline at end of 
file + ) + elasticPolicy: Optional[ElasticPolicy] = Field( + default=None, + alias="elastic_policy", + description="ElasticPolicy for elastic training" + ) diff --git a/test/unit_tests/cli/test_training.py b/test/unit_tests/cli/test_training.py index 95de870c..fa05000e 100644 --- a/test/unit_tests/cli/test_training.py +++ b/test/unit_tests/cli/test_training.py @@ -156,6 +156,62 @@ def test_optional_params(self): self.assertEqual(call_args["metadata"]["labels"]["kueue.x-k8s.io/queue-name"], "localqueue") self.assertEqual(call_args["metadata"]["annotations"]["kueue.x-k8s.io/podset-required-topology"], "topology.k8s.aws/ultraserver-id") + @patch('sys.argv', ['pytest', '--version', '1.1']) + def test_elastic_training_params(self): + """Test job creation with elastic training parameters""" + # Reload the training module with mocked sys.argv + if 'sagemaker.hyperpod.cli.commands.training' in sys.modules: + importlib.reload(sys.modules['sagemaker.hyperpod.cli.commands.training']) + + from sagemaker.hyperpod.cli.commands.training import pytorch_create + + with patch("hyperpod_pytorch_job_template.v1_1.model.HyperPodPytorchJob") as mock_hyperpod_job: + mock_instance = Mock() + mock_hyperpod_job.return_value = mock_instance + + result = self.runner.invoke( + pytorch_create, + [ + "--version", + "1.1", + "--job-name", + "elastic-test-job", + "--image", + "pytorch:latest", + "--elastic-replica-increment-step", + "2", + "--max-node-count", + "4", + "--elastic-graceful-shutdown-timeout-seconds", + "180", + "--elastic-scaling-timeout", + "30s", + ], + ) + + print(f"Command output: {result.output}") + + # Verify command succeeded + self.assertEqual(result.exit_code, 0) + self.assertIn("Using version: 1.1", result.output) + + # Verify HyperPodPytorchJob was created with elastic parameters + mock_hyperpod_job.assert_called_once() + call_args = mock_hyperpod_job.call_args[1] + + # Validate basic job configuration + self.assertEqual(call_args["metadata"]["name"], "elastic-test-job") + 
+ # Validate elastic policy configuration + self.assertIsNotNone(call_args.get("elastic_policy")) + elastic_policy = call_args["elastic_policy"] + self.assertEqual(elastic_policy.replicaIncrementStep, 2) + self.assertEqual(elastic_policy.maxReplicas, 4) + self.assertEqual(elastic_policy.gracefulShutdownTimeoutSeconds, 180) + self.assertEqual(elastic_policy.scalingTimeout, "30s") + + mock_instance.create.assert_called_once() + @patch('sagemaker.hyperpod.common.cli_decorators._namespace_exists') @patch("sagemaker.hyperpod.cli.commands.training.HyperPodPytorchJob") def test_list_jobs(self, mock_hyperpod_pytorch_job, mock_namespace_exists): From 648c0835e9210b8e8745aee69968fce216aff3f6 Mon Sep 17 00:00:00 2001 From: Molly He Date: Fri, 14 Nov 2025 22:03:51 -0800 Subject: [PATCH 02/31] Revert "feat: Implement elastic training cli arguments (#273)" This reverts commit 18428ef2b1c0562bf51a9a4b4aa2914eed441259. --- .../v1_1/model.py | 48 +----- .../v1_1/schema.json | 59 +------- .../hyperpod_pytorch_job_unified_config.py | 142 +----------------- test/unit_tests/cli/test_training.py | 56 ------- 4 files changed, 5 insertions(+), 300 deletions(-) diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py index d3a60de0..01cf8075 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py @@ -11,8 +11,7 @@ Metadata, Volumes, HostPath, - PersistentVolumeClaim, - ElasticPolicy + PersistentVolumeClaim ) from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob import yaml @@ -223,28 +222,6 @@ class PyTorchJobConfig(BaseModel): alias="required_topology", description="Required topology annotation for scheduling", ) - elastic_replica_increment_step: Optional[int] = Field( - default=None, - alias="elastic_replica_increment_step", - description="Scaling 
step size for elastic training", - ge=1, - ) - max_node_count: Optional[int] = Field( - default=None, - alias="max_node_count", - description="Maximum number of nodes for elastic training", - ge=1, - ) - elastic_graceful_shutdown_timeout_seconds: Optional[int] = Field( - default=None, - alias="elastic_graceful_shutdown_timeout_seconds", - description="Graceful shutdown timeout in seconds for elastic scaling operations" - ) - elastic_scaling_timeout: Optional[str] = Field( - default=None, - alias="elastic_scaling_timeout", - description="Scaling timeout for elastic training" - ) @field_validator('tasks_per_node', mode='before') @classmethod @@ -454,34 +431,15 @@ def build_dict(**kwargs): replica_kwargs = build_dict( name="pod", template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)), - replicas=self.node_count, - max_replicas=self.max_node_count + replicas=self.node_count ) - # Build elastic policy - elastic_policy = None - if any([ - self.elastic_replica_increment_step is not None, - self.max_node_count is not None, - self.elastic_graceful_shutdown_timeout_seconds is not None, - self.elastic_scaling_timeout is not None - ]): - elastic_policy_kwargs = build_dict( - min_replicas=self.node_count, - replica_increment_step=self.elastic_replica_increment_step, - max_replicas=self.max_node_count, - graceful_shutdown_timeout_seconds=self.elastic_graceful_shutdown_timeout_seconds, - scaling_timeout=self.elastic_scaling_timeout - ) - elastic_policy = ElasticPolicy(**elastic_policy_kwargs) - # Build job job_kwargs = build_dict( metadata=metadata_kwargs, replica_specs=[ReplicaSpec(**replica_kwargs)], nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None, - run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None, - elastic_policy=elastic_policy + run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None ) result = 
HyperPodPytorchJob(**job_kwargs) diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json index 5d400619..41abed18 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json @@ -372,64 +372,7 @@ "type": "string", "description": "Required topology annotation for scheduling", "$ref": "#/$defs/topologyLabels" - }, - "elastic_replica_increment_step": { - "anyOf": [ - { - "minimum": 1, - "type": "integer" - }, - { - "type": "null" - } - ], - "default": null, - "description": "Scaling step size for elastic training", - "title": "Elastic Training Replica Increment Step" - }, - "max_node_count": { - "anyOf": [ - { - "minimum": 1, - "type": "integer" - }, - { - "type": "null" - } - ], - "default": null, - "description": "Maximum number of nodes for elastic training", - "title": "Max Node Count" - }, - "elastic_graceful_shutdown_timeout_seconds": { - "anyOf": [ - { - "minimum": 0, - "type": "integer" - }, - { - "type": "null" - } - ], - "default": null, - "description": "Graceful shutdown timeout in seconds for elastic scaling operations", - "title": "Elastic Graceful Shutdown Timeout Seconds" - }, - "elastic_scaling_timeout": { - "anyOf": [ - { - "minLength": 1, - "type": "string" - }, - { - "type": "null" } - ], - "default": null, - "description": "Scaling timeout for elastic training", - "title": "Elastic Scaling Timeout" - } - }, "required": [ "job_name", @@ -437,4 +380,4 @@ ], "title": "PyTorchJobConfig", "type": "object" -} +} \ No newline at end of file diff --git a/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py b/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py index 5e2f30ba..a7855ef5 100644 --- a/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py +++ 
b/src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py @@ -2982,11 +2982,6 @@ class ReplicaSpec(BaseModel): default=0, description="Replicas is the desired number of replicas of the given template.", ) - maxReplicas: Optional[int] = Field( - default=None, - alias="max_replicas", - description="Maximum replicas for elastic training" - ) spares: Optional[int] = Field( default=0, description="Spares requests spare resources from Kueue. E.g. If a job is configured with 4 replicas and 2 spares, job requests resources required to run 6 pods such as cpu, gpu", @@ -2996,52 +2991,6 @@ class ReplicaSpec(BaseModel): description="Template is the object that describes the pod that will be created for this replica.", ) -class ElasticPolicy(BaseModel): - """ElasticPolicy defines the elastic training policy""" - - model_config = ConfigDict(extra="forbid") - - replicaIncrementStep: Optional[int] = Field( - default=None, - alias="replica_increment_step", - description="Step size for elastic replica scaling" - ) - minReplicas: Optional[int] = Field( - default=None, - alias="min_replicas", - description="Minimum number of replicas" - ) - maxReplicas: Optional[int] = Field( - default=None, - alias="max_replicas", - description="Maximum number of replicas" - ) - replicaDiscreteValues: Optional[List[int]] = Field( - default=None, - alias="replica_discrete_values", - description="Alternative to ReplicaIncrementStep. 
Provides exact values for total replicas count" - ) - scalingTimeout: Optional[str] = Field( - default=None, - alias="scaling_timeout", - description="Timeout for scaling operations" - ) - scalingPolicy: Optional[str] = Field( - default=None, - alias="scaling_policy", - description="Scaling policy behavior (auto, aggressive, conservative)" - ) - gracefulShutdownTimeoutSeconds: Optional[int] = Field( - default=None, - alias="graceful_shutdown_timeout_seconds", - description="Graceful shutdown timeout in seconds for elastic scaling operations" - ) - faultyScaleDownTimeoutSeconds: Optional[int] = Field( - default=None, - alias="faulty_scale_down_timeout_seconds", - description="Timeout in seconds after entering Faulted state before triggering faulty pod scale-down" - ) - class LogMonitoringConfiguration(BaseModel): """LogMonitoringRule defines the criteria used to detect a SLOW or HANGING job""" @@ -3149,11 +3098,6 @@ class RunPolicy(BaseModel): alias="ttl_seconds_after_finished", description="TTLSecondsAfterFinished is the TTL to clean up jobs. 
Set to -1 for infinite", ) - workloadMode: Optional[str] = Field( - default=None, - alias="workload_mode", - description="Workload deployment mode for elastic training (e.g., 'Deployment')", - ) class PodSets(BaseModel): @@ -3209,43 +3153,6 @@ class Pods(BaseModel): ) -class ElasticScalingStatus(BaseModel): - """ElasticScalingStatus represents the current state of elastic scaling operations""" - - model_config = ConfigDict(extra="forbid") - - targetReplicas: Optional[Dict[str, int]] = Field( - default=None, - alias="target_replicas", - description="TargetReplicas contains the desired replica counts per ReplicaSpec name", - ) - lastUpdated: Optional[str] = Field( - default=None, - alias="last_updated", - description="LastUpdated is the timestamp when this status was last modified", - ) - lastScalingTime: Optional[str] = Field( - default=None, - alias="last_scaling_time", - description="LastScalingTime tracks when the last scaling operation completed", - ) - lastRestartTime: Optional[str] = Field( - default=None, - alias="last_restart_time", - description="LastRestartTime tracks when the job was last restarted for scaleUpRestartTimeout", - ) - podsScaled: Optional[bool] = Field( - default=None, - alias="pods_scaled", - description="PodsScaled indicates whether pods have already been scaled in this scaling round", - ) - isFaultyPodScaleDown: Optional[bool] = Field( - default=None, - alias="is_faulty_pod_scale_down", - description="IsFaultyPodScaleDown indicates this scaling operation is removing faulty pods", - ) - - class RestartStatus(BaseModel): """Additional restart limiting status""" @@ -3264,33 +3171,6 @@ class RestartStatus(BaseModel): ) -class FaultyPodInstanceList(BaseModel): - """FaultyPodInstanceRecord tracks faulty pod/instances for each restart""" - - model_config = ConfigDict(extra="forbid") - - restartType: Optional[str] = Field( - default=None, - alias="restart_type", - description="RestartType indicates whether this was a PLR or JLR" - ) - 
faultyInstanceIdList: Optional[List[str]] = Field( - default_factory=list, - alias="faulty_instance_id_list", - description="FaultyInstanceIdList tracks faulty instance ids" - ) - faultyPodList: Optional[List[str]] = Field( - default_factory=list, - alias="faulty_pod_list", - description="FaultyPodList tracks faulty pod names" - ) - faultyRankList: Optional[List[str]] = Field( - default_factory=list, - alias="faulty_rank_list", - description="FaultyRankList tracks faulty pod ranks" - ) - - class HyperPodPytorchJobStatus(BaseModel): """HyperPodPytorchJobStatus defines the observed state of HyperPodPytorchJob""" @@ -3307,11 +3187,6 @@ class HyperPodPytorchJobStatus(BaseModel): alias="job_pods", description="The StatefulSet containing the training pods", ) - latestFaultyPodInstanceList: Optional[FaultyPodInstanceList] = Field( - default=None, - alias="latest_faulty_pod_instance_list", - description="LatestFaultyPodInstanceList tracks faulty pods/nodes of latest restart" - ) managerPods: Optional[ManagerPods] = Field( default=None, alias="manager_pods", description="Pod Manager pods" ) @@ -3346,16 +3221,6 @@ class HyperPodPytorchJobStatus(BaseModel): alias="restart_status", description="Additional restart limiting status", ) - elasticScalingStatus: Optional[ElasticScalingStatus] = Field( - default=None, - alias="elastic_scaling_status", - description="ElasticScalingStatus contains the current state of elastic scaling operations", - ) - elasticWorkloadRef: Optional[Dict[str, str]] = Field( - default=None, - alias="elastic_workload_ref", - description="Reference to associated ElasticWorkload (optional, only set when ElasticPolicy is present)", - ) startTime: Optional[str] = Field( default=None, alias="start_time", @@ -3380,9 +3245,4 @@ class _HyperPodPytorchJob(BaseModel): ) runPolicy: Optional[RunPolicy] = Field( default=None, alias="run_policy", description="RunPolicy" - ) - elasticPolicy: Optional[ElasticPolicy] = Field( - default=None, - alias="elastic_policy", - 
description="ElasticPolicy for elastic training" - ) + ) \ No newline at end of file diff --git a/test/unit_tests/cli/test_training.py b/test/unit_tests/cli/test_training.py index fa05000e..95de870c 100644 --- a/test/unit_tests/cli/test_training.py +++ b/test/unit_tests/cli/test_training.py @@ -156,62 +156,6 @@ def test_optional_params(self): self.assertEqual(call_args["metadata"]["labels"]["kueue.x-k8s.io/queue-name"], "localqueue") self.assertEqual(call_args["metadata"]["annotations"]["kueue.x-k8s.io/podset-required-topology"], "topology.k8s.aws/ultraserver-id") - @patch('sys.argv', ['pytest', '--version', '1.1']) - def test_elastic_training_params(self): - """Test job creation with elastic training parameters""" - # Reload the training module with mocked sys.argv - if 'sagemaker.hyperpod.cli.commands.training' in sys.modules: - importlib.reload(sys.modules['sagemaker.hyperpod.cli.commands.training']) - - from sagemaker.hyperpod.cli.commands.training import pytorch_create - - with patch("hyperpod_pytorch_job_template.v1_1.model.HyperPodPytorchJob") as mock_hyperpod_job: - mock_instance = Mock() - mock_hyperpod_job.return_value = mock_instance - - result = self.runner.invoke( - pytorch_create, - [ - "--version", - "1.1", - "--job-name", - "elastic-test-job", - "--image", - "pytorch:latest", - "--elastic-replica-increment-step", - "2", - "--max-node-count", - "4", - "--elastic-graceful-shutdown-timeout-seconds", - "180", - "--elastic-scaling-timeout", - "30s", - ], - ) - - print(f"Command output: {result.output}") - - # Verify command succeeded - self.assertEqual(result.exit_code, 0) - self.assertIn("Using version: 1.1", result.output) - - # Verify HyperPodPytorchJob was created with elastic parameters - mock_hyperpod_job.assert_called_once() - call_args = mock_hyperpod_job.call_args[1] - - # Validate basic job configuration - self.assertEqual(call_args["metadata"]["name"], "elastic-test-job") - - # Validate elastic policy configuration - 
self.assertIsNotNone(call_args.get("elastic_policy")) - elastic_policy = call_args["elastic_policy"] - self.assertEqual(elastic_policy.replicaIncrementStep, 2) - self.assertEqual(elastic_policy.maxReplicas, 4) - self.assertEqual(elastic_policy.gracefulShutdownTimeoutSeconds, 180) - self.assertEqual(elastic_policy.scalingTimeout, "30s") - - mock_instance.create.assert_called_once() - @patch('sagemaker.hyperpod.common.cli_decorators._namespace_exists') @patch("sagemaker.hyperpod.cli.commands.training.HyperPodPytorchJob") def test_list_jobs(self, mock_hyperpod_pytorch_job, mock_namespace_exists): From 99c4705d8d6ced150a30e85d48128c613f2f5cc0 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Tue, 21 Oct 2025 11:35:17 -0700 Subject: [PATCH 03/31] Add dev_space_constants.py (#255) Co-authored-by: Brian Xia --- .../cli/constants/dev_space_constants.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/sagemaker/hyperpod/cli/constants/dev_space_constants.py diff --git a/src/sagemaker/hyperpod/cli/constants/dev_space_constants.py b/src/sagemaker/hyperpod/cli/constants/dev_space_constants.py new file mode 100644 index 00000000..3d6cab95 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/constants/dev_space_constants.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+DEV_SPACE_GROUP = "sagemaker.aws.com" +DEV_SPACE_VERSION = "v1alpha1" +DEV_SPACE_PLURAL = "spaces" +DEFAULT_DEV_SPACE_PORT = "8888" +# Immutable fields that cannot be updated after dev space creation +IMMUTABLE_FIELDS = { + "storage_class_name", +} \ No newline at end of file From 6c219569fe52c5091d3191e40e11fd17e51438a7 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Wed, 22 Oct 2025 14:57:32 -0700 Subject: [PATCH 04/31] Add dev_space_access_constants.py (#256) Co-authored-by: Brian Xia --- .../cli/constants/dev_space_access_constants.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py diff --git a/src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py b/src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py new file mode 100644 index 00000000..6d41c6a0 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +DEV_SPACE_ACCESS_GROUP = "access.devspaces.sagemaker.aws.com" +DEV_SPACE_ACCESS_VERSION = "v1alpha1" +DEV_SPACE_ACCESS_PLURAL = "devspaceaccess" From 47bdccc8ff7a7072f6a155081d304a2002306618 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Wed, 22 Oct 2025 14:57:50 -0700 Subject: [PATCH 05/31] Add space_admin_config_constants.py (#257) Co-authored-by: Brian Xia --- .../constants/space_admin_config_constants.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py diff --git a/src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py b/src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py new file mode 100644 index 00000000..bd793538 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +SPACE_ADMIN_CONFIG_GROUP = "sagemaker.aws.com" +SPACE_ADMIN_CONFIG_VERSION = "v1alpha1" +SPACE_ADMIN_CONFIG_PLURAL = "spaceadminconfigs" From d2b76fa5527ae7ece5f608d98c28bcf56166aa76 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Thu, 23 Oct 2025 09:03:50 -0700 Subject: [PATCH 06/31] Add template package only (#261) Co-authored-by: Brian Xia --- .../hyperpod_dev_space_template/__init__.py | 12 ++ .../hyperpod_dev_space_template/registry.py | 20 ++ .../v1_0/__init__.py | 12 ++ .../hyperpod_dev_space_template/v1_0/model.py | 195 +++++++++++++++++ .../v1_0/schema.json | 201 ++++++++++++++++++ hyperpod-dev-space-template/pyproject.toml | 26 +++ hyperpod-dev-space-template/update_schema.py | 8 + 7 files changed, 474 insertions(+) create mode 100644 hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py create mode 100644 hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py create mode 100644 hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py create mode 100644 hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py create mode 100644 hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json create mode 100644 hyperpod-dev-space-template/pyproject.toml create mode 100644 hyperpod-dev-space-template/update_schema.py diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py b/hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py new file mode 100644 index 00000000..65490521 --- /dev/null +++ b/hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py b/hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py new file mode 100644 index 00000000..bdf80082 --- /dev/null +++ b/hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from .v1_0.model import DevSpaceConfig +from typing import Dict, Type +from pydantic import BaseModel + +# Direct version-to-model mapping +SCHEMA_REGISTRY: Dict[str, Type[BaseModel]] = { + "1.0": DevSpaceConfig, +} diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py new file mode 100644 index 00000000..65490521 --- /dev/null +++ b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py
new file mode 100644
index 00000000..e30c4884
--- /dev/null
+++ b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py
@@ -0,0 +1,195 @@
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+from typing import Optional, List, Dict, Literal
+from enum import Enum
+
+
+# TODO: Temporarily removed for private beta
+# class VolumeConfig(BaseModel):
+#     name: str = Field(
+#         ...,
+#         description="Volume name",
+#         min_length=1
+#     )
+#     type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type")
+#     mount_path: str = Field(
+#         ...,
+#         description="Mount path in container",
+#         min_length=1
+#     )
+#     path: Optional[str] = Field(
+#         None,
+#         description="Host path (required for hostPath volumes)",
+#         min_length=1
+#     )
+#     claim_name: Optional[str] = Field(
+#         None,
+#         description="PVC claim name (required for pvc volumes)",
+#         min_length=1
+#     )
+#     read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes")
+
+
+class SharedStatus(str, Enum):
+    PUBLIC = "public"
+    PRIVATE = "private"
+
+
+class Application(str, Enum):
+    JUPYTER = "jupyter"
+    CODE_EDITOR = "code-editor"
+
+
+class ResourcesConfig(BaseModel):
+    memory: Optional[str] = Field(default="1Gi", description="Memory limit")
+    cpu: Optional[str] = Field(default="500m", description="CPU limit")
+    nvidia_gpu: Optional[str] = Field(default=None, alias="nvidia.com/gpu", description="GPU limit")
+
+
+class DevSpaceConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    name: str = Field(
+        description="Dev space name",
+        min_length=1,
+        max_length=63,
+        pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$'
+    )
+    image: Optional[str] = Field(
+        default="public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu",
+        description="Container image for the dev space",
+        min_length=1
+    )
+    namespace: str = Field(
+        default="default",
+        description="Kubernetes namespace",
+        min_length=1
+    )
+    desired_status: Optional[Literal['Running', 'Stopped']] = Field(
+        default="Running",
+        alias="desired_status",
+        description="Desired status of the dev space"
+    )
+    service_account_name: Optional[str] = Field(
+        default="default",
+        alias="service_account_name",
+        description="Service account name",
+        min_length=1
+    )
+    resources: Optional[ResourcesConfig] = Field(
+        default=ResourcesConfig(),
+        description="Resource limit"
+    )
+    storage_class_name: Optional[str] = Field(
+        default=None,
+        alias="storage_class_name",
+        description="Storage class name",
+        min_length=1
+    )
+    storage_size: Optional[str] = Field(
+        default=None,
+        alias="storage_size",
+        description="Storage size (e.g., '10Gi')",
+        min_length=1
+    )
+    shared_status: Optional[SharedStatus] = Field(
+        default=SharedStatus.PRIVATE,
+        description="Space shared setting (private | public)"
+    )
+    application: Optional[Application] = Field(
+        default=Application.JUPYTER,
+        description="Application to run in the container (jupyter | code-editor)"
+    )
+    # TODO: Temporarily removed for private beta
+    # queue_name: Optional[str] = Field(
+    #     default=None,
+    #     alias="queue_name",
+    #     description="Queue name for scheduling",
+    #     min_length=1,
+    #     max_length=63,
+    #     pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$'
+    # )
+    # priority: Optional[str] = Field(
+    #     default=None,
+    #     description="Priority class for scheduling",
+    #     min_length=1
+    # )
+    # volume: Optional[List[VolumeConfig]] = Field(
+    #     default=None, description="List of volume configurations. \
+    #     Command structure: --volume name=,type=,mount_path=, \
+    #     For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
+    #     For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
+    #     Repeat the --volume flag if multiple volumes are needed \
+    #     "
+    # )
+
+    # @field_validator('volume')
+    # def validate_no_duplicates(cls, v):
+    #     """Validate no duplicate volume names or mount paths."""
+    #     if not v:
+    #         return v
+
+    #     # Check for duplicate volume names
+    #     names = [vol.name for vol in v]
+    #     if len(names) != len(set(names)):
+    #         raise ValueError("Duplicate volume names found")
+
+    #     # Check for duplicate mount paths
+    #     mount_paths = [vol.mount_path for vol in v]
+    #     if len(mount_paths) != len(set(mount_paths)):
+    #         raise ValueError("Duplicate mount paths found")
+
+    #     return v
+
+    def to_domain(self) -> Dict:
+        """
+        Convert flat config to domain model for dev space creation
+        """
+        # Create the dev space spec
+        spec = {
+            "image": self.image
+        }
+
+        # Add optional spec fields; resources are dumped by alias so the GPU key is "nvidia.com/gpu"
+        if self.desired_status is not None:
+            spec["desiredStatus"] = self.desired_status
+        if self.service_account_name is not None:
+            spec["serviceAccountName"] = self.service_account_name
+        if self.resources is not None:
+            spec["resources"] = self.resources.model_dump(exclude_none=True, by_alias=True)
+        if self.storage_class_name is not None:
+            spec["storageClassName"] = self.storage_class_name
+        if self.storage_size is not None:
+            spec["storageSize"] = self.storage_size
+        if self.shared_status is not None:
+            spec["sharedStatus"] = self.shared_status.value
+        if self.application is not None:
+            spec["application"] = self.application.value
+
+        # Create metadata
+        metadata = {"name": self.name}
+        if self.namespace is not None:
+            metadata["namespace"] = self.namespace
+
+        # Add labels for scheduling
+        # labels = {}
+        # if self.queue_name is not None:
+        #     labels["kueue.x-k8s.io/queue-name"] = self.queue_name
+        # if self.priority is not None:
+        #     labels["kueue.x-k8s.io/priority-class"] = self.priority
+
+        # if labels:
+        #     metadata["labels"] = labels
+
+        # Create the complete dev space configuration
+        dev_space_config = {
+            "apiVersion": "sagemaker.aws.com/v1alpha1",
+            "kind": "Space",
+            "metadata": metadata,
+            "spec": spec
+        }
+
+        return {
+            "name": self.name,
+            "namespace": self.namespace,
+            "dev_space_spec": dev_space_config
+        }
diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json
new file mode 100644
index 00000000..c235a896
--- /dev/null
+++ b/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json
@@ -0,0 +1,201 @@
+{
+  "$defs": {
+    "Application": {
+      "enum": [
+        "jupyter",
+        "code-editor"
+      ],
+      "title": "Application",
+      "type": "string"
+    },
+    "ResourcesConfig": {
+      "properties": {
+        "memory": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": "1Gi",
+          "description": "Memory limit",
+          "title": "Memory"
+        },
+        "cpu": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": "500m",
+          "description": "CPU limit",
+          "title": "Cpu"
+        },
+        "nvidia.com/gpu": {
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "description": "GPU limit",
+          "title": "Nvidia.Com/Gpu"
+        }
+      },
+      "title": "ResourcesConfig",
+      "type": "object"
+    },
+    "SharedStatus": {
+      "enum": [
+        "public",
+        "private"
+      ],
+      "title": "SharedStatus",
+      "type": "string"
+    }
+  },
+  "additionalProperties": false,
+  "properties": {
+    "name": {
+      "description": "Dev space name",
+      "maxLength": 63,
+      "minLength": 1,
+      "pattern": "^[a-z0-9]([-a-z0-9]*[a-z0-9])?$",
+      "title": "Name",
+      "type": "string"
+    },
+    "image": {
+      "anyOf": [
+        {
+          "minLength": 1,
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "default": 
"public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu", + "description": "Container image for the dev space", + "title": "Image" + }, + "namespace": { + "default": "default", + "description": "Kubernetes namespace", + "minLength": 1, + "title": "Namespace", + "type": "string" + }, + "desired_status": { + "anyOf": [ + { + "enum": [ + "Running", + "Stopped" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": "Running", + "description": "Desired status of the dev space", + "title": "Desired Status" + }, + "service_account_name": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": "default", + "description": "Service account name", + "title": "Service Account Name" + }, + "resources": { + "anyOf": [ + { + "$ref": "#/$defs/ResourcesConfig" + }, + { + "type": "null" + } + ], + "default": { + "memory": "1Gi", + "cpu": "500m", + "nvidia.com/gpu": null + }, + "description": "Resource limit" + }, + "storage_class_name": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Storage class name", + "title": "Storage Class Name" + }, + "storage_size": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Storage size (e.g., '10Gi')", + "title": "Storage Size" + }, + "shared_status": { + "anyOf": [ + { + "$ref": "#/$defs/SharedStatus" + }, + { + "type": "null" + } + ], + "default": "private", + "description": "Space shared setting (private | public)" + }, + "application": { + "anyOf": [ + { + "$ref": "#/$defs/Application" + }, + { + "type": "null" + } + ], + "default": "jupyter", + "description": "Application to run in the container (jupyter | code-editor)" + } + }, + "required": [ + "name" + ], + "title": "DevSpaceConfig", + "type": "object" +} \ No newline at end of file diff --git a/hyperpod-dev-space-template/pyproject.toml 
b/hyperpod-dev-space-template/pyproject.toml new file mode 100644 index 00000000..817ce58c --- /dev/null +++ b/hyperpod-dev-space-template/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "hyperpod-dev-space-template" +version = "1.0.0" +description = "Template for HyperPod Dev Space configuration" +authors = [ + {name = "Amazon Web Services"}, +] +license = {text = "Apache-2.0"} +requires-python = ">=3.8" +dependencies = [ + "pydantic>=2.0.0", +] + +[project.urls] +Homepage = "https://github.com/aws/sagemaker-hyperpod-cli" + +[tool.setuptools.packages.find] +where = ["."] +include = ["hyperpod_dev_space_template*"] + +[tool.setuptools.package-data] +"hyperpod_dev_space_template.v1_0" = ["schema.json"] diff --git a/hyperpod-dev-space-template/update_schema.py b/hyperpod-dev-space-template/update_schema.py new file mode 100644 index 00000000..01c0b87d --- /dev/null +++ b/hyperpod-dev-space-template/update_schema.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +import json +from hyperpod_dev_space_template.v1_0.model import DevSpaceConfig + +schema = DevSpaceConfig.model_json_schema() +with open('hyperpod_dev_space_template/v1_0/schema.json', 'w') as f: + json.dump(schema, f, indent=2) +print('✅ Schema updated!') From b8f7333aa64abee657751f3db1b6e7f1efc3d325 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Thu, 23 Oct 2025 16:53:04 -0700 Subject: [PATCH 07/31] Add dev_space.py CLI command (#263) * Add dev_space.py CLI command * Add dev space unit tests --------- Co-authored-by: Brian Xia --- .../hyperpod/cli/commands/dev_space.py | 226 +++++++ test/unit_tests/cli/test_dev_space.py | 574 ++++++++++++++++++ 2 files changed, 800 insertions(+) create mode 100644 src/sagemaker/hyperpod/cli/commands/dev_space.py create mode 100644 test/unit_tests/cli/test_dev_space.py diff --git a/src/sagemaker/hyperpod/cli/commands/dev_space.py 
b/src/sagemaker/hyperpod/cli/commands/dev_space.py
new file mode 100644
index 00000000..4a3aa324
--- /dev/null
+++ b/src/sagemaker/hyperpod/cli/commands/dev_space.py
@@ -0,0 +1,226 @@
+import click
+import json
+from tabulate import tabulate
+from sagemaker.hyperpod.cli.clients.kubernetes_client import KubernetesClient
+from sagemaker.hyperpod.cli.dev_space_utils import generate_click_command
+from hyperpod_dev_space_template.registry import SCHEMA_REGISTRY
+from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
+    _hyperpod_telemetry_emitter,
+)
+from sagemaker.hyperpod.common.telemetry.constants import Feature
+
+
+@click.command("hyp-dev-space")
+@generate_click_command(
+    schema_pkg="hyperpod_dev_space_template",
+    registry=SCHEMA_REGISTRY,
+)
+def dev_space_create(version, config):
+    """Create a dev-space resource."""
+
+    try:
+        name = config.get("name")
+        namespace = config.get("namespace")
+        dev_space_spec = config.get("dev_space_spec")
+
+        k8s_client = KubernetesClient()
+        k8s_client.create_dev_space(namespace, dev_space_spec)
+
+        click.echo(f"Dev space '{name}' created successfully in namespace '{namespace}'")
+    except Exception as e:
+        click.echo(f"Error creating dev space: {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+@click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table")
+def dev_space_list(namespace, output):
+    """List dev-space resources."""
+    k8s_client = KubernetesClient()
+
+    try:
+        resources = k8s_client.list_dev_spaces(namespace)
+
+        if output == "json":
+            click.echo(json.dumps(resources, indent=2))
+        else:
+            items = resources.get("items", [])
+            if items:
+                table_data = []
+                for item in items:
+                    table_data.append([
+                        item["metadata"]["name"],
+                        item["metadata"]["namespace"],
+                        item.get("status", {}).get("phase", "Unknown")
+                    ])
+                click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "STATUS"]))
+            else:
+                click.echo("No dev spaces found")
+    except Exception as e:
+        click.echo(f"Error listing dev spaces: {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+@click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml")
+def dev_space_describe(name, namespace, output):
+    """Describe a dev-space resource."""
+    k8s_client = KubernetesClient()
+
+    try:
+        resource = k8s_client.get_dev_space(namespace, name)
+        resource["metadata"].pop('managedFields', None)
+
+        if output == "json":
+            click.echo(json.dumps(resource, indent=2))
+        else:
+            import yaml
+            click.echo(yaml.dump(resource, default_flow_style=False))
+    except Exception as e:
+        click.echo(f"Error describing dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+def dev_space_delete(name, namespace):
+    """Delete a dev-space resource."""
+    k8s_client = KubernetesClient()
+
+    try:
+        k8s_client.delete_dev_space(namespace, name)
+
+        click.echo(f"Dev space '{name}' deleted successfully")
+    except Exception as e:
+        click.echo(f"Error deleting dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@generate_click_command(
+    schema_pkg="hyperpod_dev_space_template",
+    registry=SCHEMA_REGISTRY,
+    is_update=True,
+)
+def dev_space_update(version, config):
+    """Update a dev-space resource."""
+    k8s_client = KubernetesClient()
+    # Bind name before the try so the except handler below can always format its message.
+    name = config.get("name")
+    try:
+        namespace = config["namespace"]
+        dev_space_spec = config.get("dev_space_spec", {})
+
+        k8s_client.patch_dev_space(
+            namespace=namespace,
+            name=name,
+            body=dev_space_spec
+        )
+
+        click.echo(f"Dev space '{name}' updated successfully")
+    except Exception as e:
+        click.echo(f"Error updating dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+def dev_space_start(name, namespace):
+    """Start a dev-space resource."""
+    k8s_client = KubernetesClient()
+
+    try:
+        # Patch the resource to set desired status to "Running"
+        patch_body = {"spec": {"desiredStatus": "Running"}}
+        k8s_client.patch_dev_space(
+            namespace=namespace,
+            name=name,
+            body=patch_body
+        )
+
+        click.echo(f"Dev space '{name}' start requested")
+    except Exception as e:
+        click.echo(f"Error starting dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+def dev_space_stop(name, namespace):
+    """Stop a dev-space resource."""
+    k8s_client = KubernetesClient()
+
+    try:
+        # Patch the resource to set desired status to "Stopped"
+        patch_body = {"spec": {"desiredStatus": "Stopped"}}
+        k8s_client.patch_dev_space(
+            namespace=namespace,
+            name=name,
+            body=patch_body
+        )
+
+        click.echo(f"Dev space '{name}' stop requested")
+    except Exception as e:
+        click.echo(f"Error stopping dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+def dev_space_get_logs(name, namespace):
+    """Get logs for a dev-space resource."""
+    k8s_client = KubernetesClient()
+
+    try:
+        # Get pods associated with the dev space
+        pods = k8s_client.list_pods_with_labels(
+            namespace=namespace,
+            label_selector=f"sagemaker.aws.com/space-name={name}"
+        )
+
+        if not pods.items:
+            click.echo(f"No pods found for dev space '{name}'")
+            return
+
+        # Get logs from the first pod
+        pod_name = pods.items[0].metadata.name
+        logs = k8s_client.get_logs_for_pod(
+            pod_name=pod_name,
+            namespace=namespace,
+        )
+
+        click.echo(logs)
+    except Exception as e:
+        click.echo(f"Error getting logs for dev space '{name}': {e}", err=True)
+
+
+@click.command("hyp-dev-space")
+@click.option("--name", required=True, help="Name of the dev space")
+@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace")
+@click.option("--port", required=True, help="Mapping localhost port to pod")
+def dev_space_port_forward(name, namespace, port):
+    """Forward a local port to a dev-space pod."""
+    k8s_client = KubernetesClient()
+
+    try:
+        # Get pods associated with the dev space
+        pods = k8s_client.list_pods_with_labels(
+            namespace=namespace,
+            label_selector=f"sagemaker.aws.com/space-name={name}"
+        )
+
+        if not pods.items:
+            click.echo(f"No pods found for dev space '{name}'")
+            return
+
+        # Use the first pod returned; its phase is not checked here
+        pod_name = pods.items[0].metadata.name
+
+        k8s_client.port_forward_dev_space(
+            namespace=namespace,
+            pod_name=pod_name,
+            local_port=port,
+        )
+
+    except Exception as e:
+        click.echo(f"Error forwarding port for dev space '{name}': {e}", err=True)
diff --git a/test/unit_tests/cli/test_dev_space.py b/test/unit_tests/cli/test_dev_space.py
new file mode 100644
index 00000000..2b1ac434
--- /dev/null
+++ b/test/unit_tests/cli/test_dev_space.py
@@ -0,0 +1,574 @@
+import pytest
+import json
+from click.testing import CliRunner
+from unittest.mock import Mock, patch, MagicMock
+
+from sagemaker.hyperpod.cli.commands.dev_space import (
+    dev_space_create,
+    dev_space_list,
+    dev_space_describe,
+    dev_space_delete,
+    dev_space_update,
+    dev_space_start,
+    dev_space_stop,
+    dev_space_get_logs,
+    dev_space_port_forward,
+)
+
+
+class TestDevSpaceCommands:
+    """Test cases for dev space commands"""
+
+    def setup_method(self):
+        self.runner = CliRunner()
+        self.mock_k8s_client = Mock()
+
+    
@patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_create_success(self, mock_k8s_client_class, mock_load_schema): + """Test successful dev space creation""" + # Mock schema loading + mock_load_schema.return_value = { + "properties": { + "name": {"type": "string"}, + "namespace": {"type": "string"} + }, + "required": ["name"] + } + + # Mock model registry + mock_model = Mock() + mock_model.return_value = Mock() + mock_model.return_value.to_domain.return_value = { + "name": "test-space", + "namespace": "test-ns", + "dev_space_spec": {"spec": {"image": "test-image"}} + } + + # Mock KubernetesClient + mock_k8s_instance = Mock() + mock_k8s_client_class.return_value = mock_k8s_instance + + with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + result = self.runner.invoke(dev_space_create, [ + '--version', '1.0', + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Dev space 'test-space' created successfully" in result.output + mock_k8s_instance.create_dev_space.assert_called_once() + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_dev_space_create_missing_required_args(self, mock_load_schema): + """Test dev space creation with missing required arguments""" + mock_load_schema.return_value = { + "properties": {"name": {"type": "string"}}, + "required": ["name"] + } + + result = self.runner.invoke(dev_space_create, ['--version', '1.0']) + assert result.exit_code != 0 + assert 'Missing option' in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_create_k8s_error(self, mock_k8s_client_class): + """Test dev space creation error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.create_dev_space.side_effect = Exception("Creation failed") + mock_k8s_client_class.return_value = 
mock_k8s_instance + + mock_model = Mock() + mock_model.return_value = Mock() + mock_model.return_value.to_domain.return_value = { + "name": "test-space", + "namespace": "test-ns", + "dev_space_spec": {} + } + + with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + with patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') as mock_load_schema: + mock_load_schema.return_value = { + "properties": { + "name": {"type": "string"}, + "namespace": {"type": "string"} + }, + "required": ["name", "namespace"] + } + result = self.runner.invoke(dev_space_create, [ + '--version', '1.0', + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error creating dev space: Creation failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_list_table_output(self, mock_k8s_client_class): + """Test dev space list with table output""" + mock_k8s_instance = Mock() + mock_k8s_instance.list_dev_spaces.return_value = { + "items": [ + { + "metadata": {"name": "space1", "namespace": "ns1"}, + "status": {"phase": "Running"} + }, + { + "metadata": {"name": "space2", "namespace": "ns2"}, + "status": {"phase": "Stopped"} + } + ] + } + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_list, [ + '--namespace', 'test-ns', + '--output', 'table' + ]) + + assert result.exit_code == 0 + assert "space1" in result.output + assert "space2" in result.output + mock_k8s_instance.list_dev_spaces.assert_called_once_with('test-ns') + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_list_json_output(self, mock_k8s_client_class): + """Test dev space list with JSON output""" + mock_resources = { + "items": [ + {"metadata": {"name": "space1", "namespace": "ns1"}} + ] + } + mock_k8s_instance = Mock() + mock_k8s_instance.list_dev_spaces.return_value = mock_resources + 
mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_list, [ + '--namespace', 'test-ns', + '--output', 'json' + ]) + + assert result.exit_code == 0 + output_json = json.loads(result.output) + assert output_json == mock_resources + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_list_empty(self, mock_k8s_client_class): + """Test dev space list with no items""" + mock_k8s_instance = Mock() + mock_k8s_instance.list_dev_spaces.return_value = {"items": []} + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_list, [ + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "No dev spaces found" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_list_error(self, mock_k8s_client_class): + """Test dev space list error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.list_dev_spaces.side_effect = Exception("List failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_list, [ + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error listing dev spaces: List failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_describe_yaml_output(self, mock_k8s_client_class): + """Test dev space describe with YAML output""" + mock_resource = {"metadata": {"name": "test-space"}} + mock_k8s_instance = Mock() + mock_k8s_instance.get_dev_space.return_value = mock_resource + mock_k8s_client_class.return_value = mock_k8s_instance + + with patch('yaml.dump') as mock_yaml_dump: + mock_yaml_dump.return_value = "yaml_output" + result = self.runner.invoke(dev_space_describe, [ + '--name', 'test-space', + '--namespace', 'test-ns', + ]) + + assert result.exit_code == 0 + assert "yaml_output" in result.output + 
mock_k8s_instance.get_dev_space.assert_called_once_with('test-ns', 'test-space') + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_describe_json_output(self, mock_k8s_client_class): + """Test dev space describe with JSON output""" + mock_resource = {"metadata": {"name": "test-space"}} + mock_k8s_instance = Mock() + mock_k8s_instance.get_dev_space.return_value = mock_resource + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_describe, [ + '--name', 'test-space', + '--namespace', 'test-ns', + '--output', 'json' + ]) + + assert result.exit_code == 0 + output_json = json.loads(result.output) + assert output_json == mock_resource + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_describe_k8s_error(self, mock_k8s_client_class): + """Test dev space describe error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.get_dev_space.side_effect = Exception("Describe failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_describe, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error describing dev space 'test-space': Describe failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_delete_success(self, mock_k8s_client_class): + """Test successful dev space deletion""" + mock_k8s_instance = Mock() + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_delete, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Dev space 'test-space' deleted successfully" in result.output + mock_k8s_instance.delete_dev_space.assert_called_once_with('test-ns', 'test-space') + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_delete_k8s_error(self, 
mock_k8s_client_class): + """Test dev space delete error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.delete_dev_space.side_effect = Exception("Delete failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_delete, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error deleting dev space 'test-space': Delete failed" in result.output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_update_success(self, mock_k8s_client_class, mock_load_schema): + """Test successful dev space update""" + # Mock schema loading + mock_load_schema.return_value = { + "properties": { + "name": {"type": "string"}, + "namespace": {"type": "string"} + }, + "required": ["name"] + } + + # Mock model registry + mock_model = Mock() + mock_model.return_value = Mock() + mock_model.return_value.to_domain.return_value = { + "name": "test-space", + "namespace": "test-ns", + "dev_space_spec": {"spec": {"image": "updated-image"}} + } + + # Mock KubernetesClient + mock_k8s_instance = Mock() + mock_k8s_client_class.return_value = mock_k8s_instance + + with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + result = self.runner.invoke(dev_space_update, [ + '--version', '1.0', + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Dev space 'test-space' updated successfully" in result.output + mock_k8s_instance.patch_dev_space.assert_called_once() + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_update_k8_error(self, mock_k8s_client_class): + """Test dev space update error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.patch_dev_space.side_effect = Exception("Update failed") + mock_k8s_client_class.return_value = mock_k8s_instance + 
+ mock_model = Mock() + mock_model.return_value = Mock() + mock_model.return_value.to_domain.return_value = { + "name": "test-space", + "namespace": "test-ns", + "dev_space_spec": {} + } + + with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + with patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') as mock_load_schema: + mock_load_schema.return_value = { + "properties": { + "name": {"type": "string"}, + "namespace": {"type": "string"} + }, + "required": ["name", "namespace"] + } + result = self.runner.invoke(dev_space_update, [ + '--version', '1.0', + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error updating dev space 'test-space': Update failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_start_success(self, mock_k8s_client_class): + """Test successful dev space start""" + mock_k8s_instance = Mock() + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_start, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Dev space 'test-space' start requested" in result.output + mock_k8s_instance.patch_dev_space.assert_called_once_with( + namespace='test-ns', + name='test-space', + body={"spec": {"desiredStatus": "Running"}} + ) + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_start_k8s_error(self, mock_k8s_client_class): + """Test dev space start error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.patch_dev_space.side_effect = Exception("Start failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_start, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error starting dev space 'test-space': Start failed" in result.output + + 
@patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_stop_success(self, mock_k8s_client_class): + """Test successful dev space stop""" + mock_k8s_instance = Mock() + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_stop, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Dev space 'test-space' stop requested" in result.output + mock_k8s_instance.patch_dev_space.assert_called_once_with( + namespace='test-ns', + name='test-space', + body={"spec": {"desiredStatus": "Stopped"}} + ) + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_stop_k8s_error(self, mock_k8s_client_class): + """Test dev space stop error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.patch_dev_space.side_effect = Exception("Stop failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_stop, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error stopping dev space 'test-space': Stop failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_get_logs_success(self, mock_k8s_client_class): + """Test successful dev space get logs""" + mock_pod = Mock() + mock_pod.metadata.name = "test-pod" + mock_pods = Mock() + mock_pods.items = [mock_pod] + + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.return_value = mock_pods + mock_k8s_instance.get_logs_for_pod.return_value = "test logs" + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_get_logs, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "test logs" in result.output + mock_k8s_instance.list_pods_with_labels.assert_called_once_with( + namespace='test-ns', + 
label_selector='sagemaker.aws.com/space-name=test-space' + ) + mock_k8s_instance.get_logs_for_pod.assert_called_once_with( + pod_name='test-pod', + namespace='test-ns' + ) + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_get_logs_no_pods(self, mock_k8s_client_class): + """Test dev space get logs with no pods""" + mock_pods = Mock() + mock_pods.items = [] + + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.return_value = mock_pods + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_get_logs, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "No pods found for dev space 'test-space'" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_get_logs_k8s_error(self, mock_k8s_client_class): + """Test dev space get logs error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.side_effect = Exception("List pod failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_get_logs, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + + assert result.exit_code == 0 + assert "Error getting logs for dev space 'test-space': List pod failed" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_port_forward_success(self, mock_k8s_client_class): + """Test successful dev space port forward""" + mock_pod = Mock() + mock_pod.metadata.name = "test-pod" + mock_pods = Mock() + mock_pods.items = [mock_pod] + + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.return_value = mock_pods + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_port_forward, [ + '--name', 'test-space', + '--namespace', 'test-ns', + '--port', '8080' + ]) + + assert result.exit_code == 0 + 
mock_k8s_instance.list_pods_with_labels.assert_called_once_with( + namespace='test-ns', + label_selector='sagemaker.aws.com/space-name=test-space' + ) + mock_k8s_instance.port_forward_dev_space.assert_called_once_with( + namespace='test-ns', + pod_name='test-pod', + local_port='8080' + ) + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_port_forward_no_pods(self, mock_k8s_client_class): + """Test dev space port forward with no pods""" + mock_pods = Mock() + mock_pods.items = [] + + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.return_value = mock_pods + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_port_forward, [ + '--name', 'test-space', + '--namespace', 'test-ns', + '--port', '8080' + ]) + + assert result.exit_code == 0 + assert "No pods found for dev space 'test-space'" in result.output + + @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') + def test_dev_space_port_forward_error(self, mock_k8s_client_class): + """Test dev space port forward error handling""" + mock_k8s_instance = Mock() + mock_k8s_instance.list_pods_with_labels.side_effect = Exception("Port forward failed") + mock_k8s_client_class.return_value = mock_k8s_instance + + result = self.runner.invoke(dev_space_port_forward, [ + '--name', 'test-space', + '--namespace', 'test-ns', + '--port', '8080' + ]) + + assert result.exit_code == 0 + assert "Error forwarding port for dev space 'test-space': Port forward failed" in result.output + + def test_missing_required_arguments(self): + """Test commands with missing required arguments""" + # Test create without name + result = self.runner.invoke(dev_space_create, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test describe without name + result = self.runner.invoke(dev_space_describe, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert 
"Missing option '--name'" in result.output + + # Test delete without name + result = self.runner.invoke(dev_space_delete, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test update without name + result = self.runner.invoke(dev_space_update, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test start without name + result = self.runner.invoke(dev_space_start, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test stop without name + result = self.runner.invoke(dev_space_stop, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test get logs without name + result = self.runner.invoke(dev_space_get_logs, ['--namespace', 'test-ns']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + # Test port forward without port + result = self.runner.invoke(dev_space_port_forward, [ + '--name', 'test-space', + '--namespace', 'test-ns' + ]) + assert result.exit_code == 2 + assert "Missing option '--port'" in result.output From fd7e6446abfb6e3074337daeb14a4e5ec5166ee9 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Thu, 23 Oct 2025 16:53:27 -0700 Subject: [PATCH 08/31] Add dev_space_utils.py to work with the dev space template model (#262) * Add dev_space_utils.py * Add unit tests for dev_space_utils --------- Co-authored-by: Brian Xia --- src/sagemaker/hyperpod/cli/dev_space_utils.py | 160 ++++++++ test/unit_tests/cli/test_dev_space_utils.py | 363 ++++++++++++++++++ 2 files changed, 523 insertions(+) create mode 100644 src/sagemaker/hyperpod/cli/dev_space_utils.py create mode 100644 test/unit_tests/cli/test_dev_space_utils.py diff --git a/src/sagemaker/hyperpod/cli/dev_space_utils.py b/src/sagemaker/hyperpod/cli/dev_space_utils.py new file mode 100644 index 00000000..f9f94d03 --- 
/dev/null +++ b/src/sagemaker/hyperpod/cli/dev_space_utils.py @@ -0,0 +1,160 @@ +import json +import pkgutil +import click +from typing import Callable, Optional, Mapping, Type, Dict, Any +from pydantic import ValidationError +from sagemaker.hyperpod.cli.constants.dev_space_constants import IMMUTABLE_FIELDS + + +def load_schema_for_version( + version: str, + base_package: str, +) -> dict: + """ + Load schema.json from the top-level .vX_Y_Z package. + """ + ver_pkg = f"{base_package}.v{version.replace('.', '_')}" + raw = pkgutil.get_data(ver_pkg, "schema.json") + if raw is None: + raise click.ClickException( + f"Could not load schema.json for version {version} " + f"(looked in package {ver_pkg})" + ) + return json.loads(raw) + + +def generate_click_command( + *, + version_key: Optional[str] = None, + schema_pkg: str = "hyperpod_dev_space_template", + registry: Mapping[str, Type] = None, + is_update: bool = False, +) -> Callable: + """ + Decorator factory for dev space commands. + """ + if registry is None: + raise ValueError("You must pass a registry mapping version→Model") + + # get schema defaults for manually handled options + schema = load_schema_for_version(version_key or "1.0", schema_pkg) + props = schema.get("properties", {}) + + def decorator(func: Callable) -> Callable: + # build resources from CPU/memory options + def _build_resources(cpu, memory, gpu): + if cpu is None and memory is None and gpu is None: + return None + + default_resources = props["resources"]["default"] + return { + "cpu": cpu or default_resources["cpu"], + "memory": memory or default_resources["memory"], + "nvidia.com/gpu": gpu or default_resources["nvidia.com/gpu"] + } + + # 1) the wrapper click will call + def wrapped_func(*args, **kwargs): + version = version_key or kwargs.pop("version", "1.0") + + Model = registry.get(version) + if Model is None: + raise click.ClickException(f"Unsupported schema version: {version}") + + resources = _build_resources(kwargs.pop("cpu", None), 
kwargs.pop("memory", None), kwargs.pop("gpu", None)) + if resources is not None: + kwargs["resources"] = resources + + # filter out None/empty values so Pydantic model defaults apply + filtered_kwargs = {} + for key, value in kwargs.items(): + if value is not None: + filtered_kwargs[key] = value + + try: + flat = Model(**filtered_kwargs) + domain_config = flat.to_domain() + except ValidationError as e: + error_messages = [] + for err in e.errors(): + loc = ".".join(str(x) for x in err["loc"]) + msg = err["msg"] + error_messages.append(f" – {loc}: {msg}") + + raise click.UsageError( + f"❌ Configuration validation errors:\n" + "\n".join(error_messages) + ) + + return func(version, domain_config) + + # 2) inject click options from JSON Schema + wrapped_func = click.option( + "--cpu", + type=str, + default=None, + help="CPU resource, e.g. '250m'", + )(wrapped_func) + + wrapped_func = click.option( + "--memory", + type=str, + default=None, + help="Memory resource, e.g. '256Mi'", + )(wrapped_func) + + wrapped_func = click.option( + "--gpu", + type=str, + default=None, + help="Gpu resource, e.g. 
'1'", + )(wrapped_func) + + # Exclude the props that were handled out of the below for loop + excluded_props = set( + [ + "resources", + "version", + ] + ) + + # 3) auto-inject all schema.json fields + reqs = set(schema.get("required", [])) + + for name, spec in reversed(list(props.items())): + if name in excluded_props: + continue + + if is_update and name in IMMUTABLE_FIELDS: + continue + + # infer click type + if "enum" in spec: + ctype = click.Choice(spec["enum"]) + elif spec.get("type") == "integer": + ctype = int + elif spec.get("type") == "number": + ctype = float + elif spec.get("type") == "boolean": + ctype = bool + else: + ctype = str + + wrapped_func = click.option( + f"--{name.replace('_','-')}", + required=(name in reqs), + default=spec.get("default", None), + type=ctype, + help=spec.get("description", ""), + )(wrapped_func) + + # 4) if no hard-coded version_key, inject the top-level --version flag + if version_key is None: + wrapped_func = click.option( + "--version", + default="1.0", + help="Schema version to use", + )(wrapped_func) + + return wrapped_func + + return decorator diff --git a/test/unit_tests/cli/test_dev_space_utils.py b/test/unit_tests/cli/test_dev_space_utils.py new file mode 100644 index 00000000..d9e3a203 --- /dev/null +++ b/test/unit_tests/cli/test_dev_space_utils.py @@ -0,0 +1,363 @@ +import pytest +import json +import click +from click.testing import CliRunner +from unittest.mock import Mock, patch +from pydantic import ValidationError + +from sagemaker.hyperpod.cli.dev_space_utils import load_schema_for_version, generate_click_command + + +class TestLoadSchemaForVersion: + @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + def test_success(self, mock_get_data): + """Test successful schema loading""" + data = {"properties": {"name": {"type": "string"}}} + mock_get_data.return_value = json.dumps(data).encode() + + result = load_schema_for_version('1.2', 'test_package') + + assert result == data + 
mock_get_data.assert_called_once_with('test_package.v1_2', 'schema.json') + + @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + def test_schema_not_found(self, mock_get_data): + """Test handling of missing schema file""" + mock_get_data.return_value = None + + with pytest.raises(click.ClickException) as exc: + load_schema_for_version('1.0', 'test_package') + + assert "Could not load schema.json for version 1.0" in str(exc.value) + + @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + def test_invalid_json_schema(self, mock_get_data): + """Test handling of invalid JSON in schema file""" + mock_get_data.return_value = b'invalid json' + + with pytest.raises(json.JSONDecodeError): + load_schema_for_version('1.0', 'test_package') + + +class TestGenerateClickCommand: + def setup_method(self): + self.runner = CliRunner() + + def test_missing_registry(self): + """Test that registry is required""" + with pytest.raises(ValueError) as exc: + generate_click_command(schema_pkg="test_package") + assert "You must pass a registry mapping" in str(exc.value) + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_unsupported_version(self, mock_load_schema): + """Test handling of unsupported version""" + mock_load_schema.return_value = {'properties': {}, 'required': []} + registry = {} + + @click.command() + @generate_click_command(registry=registry) + def cmd(version, domain_config): + click.echo('should not reach here') + + result = self.runner.invoke(cmd, []) + assert result.exit_code != 0 + assert 'Unsupported schema version: 1.0' in result.output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_version_handling(self, mock_load_schema): + """Test version handling in command generation""" + schema = {'properties': {}, 'required': []} + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + pass + def to_domain(self): + return self + + 
registry = {'2.0': DummyModel} + + @click.command() + @generate_click_command( + version_key='2.0', + schema_pkg="test_package", + registry=registry + ) + def cmd(version, domain_config): + click.echo(version) + + result = self.runner.invoke(cmd, []) + assert result.exit_code == 0 + assert result.output.strip() == '2.0' + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_resources_building(self, mock_load_schema): + """Test CPU and memory resource building""" + schema = { + 'properties': { + 'resources': { + 'default': { + 'cpu': '250m', + 'memory': '256Mi', + 'nvidia.com/gpu': None + } + } + }, + 'required': [] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.resources = kwargs.get('resources') + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(domain_config.resources)) + + # Test with custom CPU and memory + result = self.runner.invoke(cmd, ['--cpu', '1000m', '--memory', '1Gi']) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['cpu'] == '1000m' + assert output['memory'] == '1Gi' + assert output['nvidia.com/gpu'] is None + + # Test with only CPU + result = self.runner.invoke(cmd, ['--cpu', '750m']) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['cpu'] == '750m' + assert output['memory'] == '256Mi' # default + + # Test with no resources specified + result = self.runner.invoke(cmd, []) + assert result.exit_code == 0 + assert result.output.strip() == 'null' + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_type_conversion(self, mock_load_schema): + """Test type conversion for different parameter types""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'desired_status': 
{'type': 'string', 'enum': ['Running', 'Stopped']}, + 'storage_size': {'type': 'string'}, + 'port': {'type': 'integer'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + click.echo(json.dumps({ + 'name': domain_config.name, + 'desired_status': getattr(domain_config, 'desired_status', None), + 'storage_size': getattr(domain_config, 'storage_size', None), + 'port': getattr(domain_config, 'port', None) + })) + + # Test string and enum types + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--desired-status', 'Running', + '--storage-size', '20Gi' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['name'] == 'test-space' + assert output['desired_status'] == 'Running' + assert output['storage_size'] == '20Gi' + + # Test invalid enum value + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--desired-status', 'Invalid' + ]) + assert result.exit_code == 2 + assert "Invalid value" in result.output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_successful_command_execution(self, mock_load_schema): + """Test successful command execution with valid parameters""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'image': {'type': 'string', 'default': 'default-image'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + 
click.echo(f'success: {domain_config.name}') + + # Test successful execution + result = self.runner.invoke(cmd, ['--name', 'test-space']) + assert result.exit_code == 0 + assert 'success: test-space' in result.output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_immutable_fields_excluded_in_update(self, mock_load_schema): + """Test that immutable fields are excluded in update mode""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'storage_class_name': {'type': 'string'}, + 'image': {'type': 'string'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command( + registry=registry, + schema_pkg="hyperpod_dev_space_template", + is_update=True + ) + def cmd(version, domain_config): + click.echo('success') + + # Get the command's help to check available options + result = self.runner.invoke(cmd, ['--help']) + assert result.exit_code == 0 + # storage_class_name should not be available in update mode + assert '--storage-class-name' not in result.output + # but other fields should be available + assert '--name' in result.output + assert '--image' in result.output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_filtered_kwargs(self, mock_load_schema): + """Test that None/empty values are filtered out""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'image': {'type': 'string', 'default': 'default-image'}, + 'namespace': {'type': 'string', 'default': None} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.received_kwargs = kwargs + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + 
@generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + # Check that None values were filtered out + click.echo(json.dumps(domain_config.received_kwargs)) + + result = self.runner.invoke(cmd, ['--name', 'test-space']) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['name'] == 'test-space' + assert output['image'] == 'default-image' + assert 'namespace' not in output + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_default_version_injection(self, mock_load_schema): + """Test that version flag is injected when no version_key is provided""" + schema = {'properties': {}, 'required': []} + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): pass + def to_domain(self): return self + + registry = {'1.0': DummyModel, '2.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + click.echo(version) + + # Test default version + result = self.runner.invoke(cmd, []) + assert result.exit_code == 0 + assert result.output.strip() == '1.0' + + # Test custom version + result = self.runner.invoke(cmd, ['--version', '2.0']) + print(result.output) + assert result.exit_code == 0 + assert result.output.strip() == '2.0' + + @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + def test_schema_defaults_and_required_fields(self, mock_load_schema): + """Test handling of schema defaults and required fields""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'image': {'type': 'string', 'default': 'default-image'}, + 'namespace': {'type': 'string', 'default': None} + }, + 'required': ['name', 'namespace'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + 
registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + def cmd(version, domain_config): + click.echo('success') + + # Test missing required field + result = self.runner.invoke(cmd, []) + assert result.exit_code == 2 + assert "Missing option" in result.output + + # Test with required field provided + result = self.runner.invoke(cmd, ['--name', 'test-space', '--namespace', 'test-ns']) + print(result.output) + assert result.exit_code == 0 + assert result.output.strip() == 'success' From ca486b7b340736841251afbff08f1d4c515394c3 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Tue, 28 Oct 2025 08:58:47 -0700 Subject: [PATCH 09/31] Add dev space CLI (#269) --- setup.py | 2 + .../hyperpod/cli/clients/kubernetes_client.py | 133 ++++++++++++++++++ .../hyperpod/cli/commands/dev_space.py | 29 ---- src/sagemaker/hyperpod/cli/hyp_cli.py | 51 ++++++- test/unit_tests/cli/test_dev_space.py | 73 ---------- .../clients/test_kubernetes_client.py | 102 ++++++++++++++ 6 files changed, 282 insertions(+), 108 deletions(-) diff --git a/setup.py b/setup.py index 8aa1a32e..59ab053c 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,8 @@ "hyperpod-custom-inference-template>=1.0.0, <2.0.0", "hyperpod-jumpstart-inference-template>=1.0.0, <2.0.0", "hyperpod-cluster-stack-template>=1.0.0, <2.0.0" + # TODO: need to uncomment before pushing to master + # "hyperpod_dev_space_template>=1.0.0, <2.0.0" ], entry_points={ "console_scripts": [ diff --git a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py index 3e6d0202..d7855029 100644 --- a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py +++ b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py @@ -40,6 +40,22 @@ PYTORCH_CUSTOM_OBJECT_PLURAL, PYTORCH_CUSTOM_OBJECT_VERSION, ) +from sagemaker.hyperpod.cli.constants.dev_space_constants import ( + DEV_SPACE_GROUP, + DEV_SPACE_VERSION, + 
DEV_SPACE_PLURAL, + DEFAULT_DEV_SPACE_PORT, +) +from sagemaker.hyperpod.cli.constants.space_admin_config_constants import ( + SPACE_ADMIN_CONFIG_GROUP, + SPACE_ADMIN_CONFIG_VERSION, + SPACE_ADMIN_CONFIG_PLURAL, +) +from sagemaker.hyperpod.cli.constants.dev_space_access_constants import ( + DEV_SPACE_ACCESS_GROUP, + DEV_SPACE_ACCESS_VERSION, + DEV_SPACE_ACCESS_PLURAL, +) from sagemaker.hyperpod.cli.utils import setup_logger logger = setup_logger(__name__) @@ -358,4 +374,121 @@ def get_cluster_queue(self, cluster_queue_name: str): plural=CLUSTER_QUEUE_PRIORITY_CLASS_CUSTOM_OBJECT_PLURAL, name=cluster_queue_name ) + + def create_dev_space(self, namespace: str, dev_space_spec: dict): + return client.CustomObjectsApi().create_namespaced_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + namespace=namespace, + plural=DEV_SPACE_PLURAL, + body=dev_space_spec + ) + + def list_dev_spaces(self, namespace: str): + if namespace: + return client.CustomObjectsApi().list_namespaced_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + namespace=namespace, + plural=DEV_SPACE_PLURAL + ) + else: + return client.CustomObjectsApi().list_cluster_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + plural=DEV_SPACE_PLURAL + ) + + def get_dev_space(self, namespace: str, name: str): + return client.CustomObjectsApi().get_namespaced_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + namespace=namespace, + plural=DEV_SPACE_PLURAL, + name=name + ) + + def delete_dev_space(self, namespace: str, name: str): + return client.CustomObjectsApi().delete_namespaced_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + namespace=namespace, + plural=DEV_SPACE_PLURAL, + name=name + ) + + def patch_dev_space(self, namespace: str, name: str, body: dict): + return client.CustomObjectsApi().patch_namespaced_custom_object( + group=DEV_SPACE_GROUP, + version=DEV_SPACE_VERSION, + namespace=namespace, + 
plural=DEV_SPACE_PLURAL, + name=name, + body=body + ) + + + + # Space Admin Configuration methods + def create_space_admin_config(self, namespace: str, config_spec: dict): + return client.CustomObjectsApi().create_namespaced_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + namespace=namespace, + plural=SPACE_ADMIN_CONFIG_PLURAL, + body=config_spec + ) + + def list_space_admin_configs(self, namespace: str = None): + if namespace: + return client.CustomObjectsApi().list_namespaced_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + namespace=namespace, + plural=SPACE_ADMIN_CONFIG_PLURAL + ) + else: + return client.CustomObjectsApi().list_cluster_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + plural=SPACE_ADMIN_CONFIG_PLURAL + ) + + def get_space_admin_config(self, namespace: str, name: str): + return client.CustomObjectsApi().get_namespaced_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + namespace=namespace, + plural=SPACE_ADMIN_CONFIG_PLURAL, + name=name + ) + + def delete_space_admin_config(self, namespace: str, name: str): + return client.CustomObjectsApi().delete_namespaced_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + namespace=namespace, + plural=SPACE_ADMIN_CONFIG_PLURAL, + name=name + ) + + def patch_space_admin_config(self, namespace: str, name: str, body: dict): + return client.CustomObjectsApi().patch_namespaced_custom_object( + group=SPACE_ADMIN_CONFIG_GROUP, + version=SPACE_ADMIN_CONFIG_VERSION, + namespace=namespace, + plural=SPACE_ADMIN_CONFIG_PLURAL, + name=name, + body=body + ) + + def create_dev_space_access(self, namespace: str, config_spec: dict): + return client.CustomObjectsApi().create_namespaced_custom_object( + group=DEV_SPACE_ACCESS_GROUP, + version=DEV_SPACE_ACCESS_VERSION, + namespace=namespace, + plural=DEV_SPACE_ACCESS_PLURAL, + 
body=config_spec + ) + # Add more methods to access other APIs as needed diff --git a/src/sagemaker/hyperpod/cli/commands/dev_space.py b/src/sagemaker/hyperpod/cli/commands/dev_space.py index 4a3aa324..3164fca4 100644 --- a/src/sagemaker/hyperpod/cli/commands/dev_space.py +++ b/src/sagemaker/hyperpod/cli/commands/dev_space.py @@ -194,33 +194,4 @@ def dev_space_get_logs(name, namespace): click.echo(f"Error getting logs for dev space '{name}': {e}", err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") -@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -@click.option("--port", required=True, help="Mapping localhost port to pod") -def dev_space_port_forward(name, namespace, port): - """Forward a local port to a dev-space pod.""" - k8s_client = KubernetesClient() - - try: - # Get pods associated with the dev space - pods = k8s_client.list_pods_with_labels( - namespace=namespace, - label_selector=f"sagemaker.aws.com/space-name={name}" - ) - - if not pods.items: - click.echo(f"No pods found for dev space '{name}'") - return - - # Get the first running pod - pod_name = pods.items[0].metadata.name - k8s_client.port_forward_dev_space( - namespace=namespace, - pod_name=pod_name, - local_port=port, - ) - - except Exception as e: - click.echo(f"Error forwarding port for dev space '{name}': {e}", err=True) diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index 872c21ee..96a1bbcc 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -38,6 +38,16 @@ js_get_operator_logs, custom_get_operator_logs, ) +from sagemaker.hyperpod.cli.commands.dev_space import ( + dev_space_create, + dev_space_list, + dev_space_describe, + dev_space_delete, + dev_space_update, + dev_space_start, + dev_space_stop, + dev_space_get_logs, +) from sagemaker.hyperpod.cli.commands.init import ( init, @@ -97,7 +107,7 @@ 
def parse_args(self, ctx, args): @cli.group(cls=CLICommand, default_cmd='_default_create') def create(): """ - Create endpoints, pytorch jobs or cluster stacks. + Create endpoints, pytorch jobs, cluster stacks, dev space, dev space access or space admin config. If only used as 'hyp create' without [OPTIONS] COMMAND [ARGS] during init experience, then it will validate configuration and render template files for deployment. @@ -113,26 +123,41 @@ def create(): @cli.group(cls=CLICommand) def list(): - """List endpoints, pytorch jobs or cluster stacks.""" + """List endpoints, pytorch jobs, cluster stacks or dev spaces.""" pass @cli.group(cls=CLICommand) def describe(): - """Describe endpoints, pytorch jobs or cluster stacks.""" + """Describe endpoints, pytorch jobs or cluster stacks, dev spaces or space admin configs.""" pass @cli.group(cls=CLICommand) def update(): - """Update an existing HyperPod cluster configuration.""" + """Update an existing HyperPod cluster configuration, dev space, or space admin config.""" pass @cli.group(cls=CLICommand) def delete(): - """Delete endpoints or pytorch jobs.""" + """Delete endpoints, pytorch jobs, dev space, dev space access or space admin config.""" + pass + + +@cli.group(cls=CLICommand) +def start(): + """Start dev space resources.""" + pass + + +@cli.group(cls=CLICommand) +def stop(): + """Stop dev space resources.""" pass + + + @cli.group(cls=CLICommand) def list_pods(): """List pods for endpoints or pytorch jobs.""" @@ -141,7 +166,7 @@ def list_pods(): @cli.group(cls=CLICommand) def get_logs(): - """Get pod logs for endpoints or pytorch jobs.""" + """Get pod logs for endpoints, pytorch jobs or dev spaces.""" pass @@ -171,26 +196,37 @@ def exec(): create.add_command(pytorch_create) create.add_command(js_create) create.add_command(custom_create) + _default_create.hidden = True create.add_command(_default_create) +create.add_command(dev_space_create) list.add_command(list_jobs) list.add_command(js_list) 
list.add_command(custom_list) list.add_command(list_cluster_stacks) +list.add_command(dev_space_list) describe.add_command(pytorch_describe) describe.add_command(js_describe) describe.add_command(custom_describe) describe.add_command(describe_cluster_stack) + describe.add_command(describe_cluster) +describe.add_command(dev_space_describe) update.add_command(update_cluster) +update.add_command(dev_space_update) delete.add_command(pytorch_delete) delete.add_command(js_delete) delete.add_command(custom_delete) delete.add_command(delete_cluster_stack) +delete.add_command(dev_space_delete) + +start.add_command(dev_space_start) + +stop.add_command(dev_space_stop) list_pods.add_command(pytorch_list_pods) list_pods.add_command(js_list_pods) @@ -199,6 +235,9 @@ def exec(): get_logs.add_command(pytorch_get_logs) get_logs.add_command(js_get_logs) get_logs.add_command(custom_get_logs) +get_logs.add_command(dev_space_get_logs) + + get_operator_logs.add_command(pytorch_get_operator_logs) get_operator_logs.add_command(js_get_operator_logs) diff --git a/test/unit_tests/cli/test_dev_space.py b/test/unit_tests/cli/test_dev_space.py index 2b1ac434..cd7ddae4 100644 --- a/test/unit_tests/cli/test_dev_space.py +++ b/test/unit_tests/cli/test_dev_space.py @@ -12,7 +12,6 @@ dev_space_start, dev_space_stop, dev_space_get_logs, - dev_space_port_forward, ) @@ -464,70 +463,6 @@ def test_dev_space_get_logs_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error getting logs for dev space 'test-space': List pod failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_port_forward_success(self, mock_k8s_client_class): - """Test successful dev space port forward""" - mock_pod = Mock() - mock_pod.metadata.name = "test-pod" - mock_pods = Mock() - mock_pods.items = [mock_pod] - - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.return_value = mock_pods - mock_k8s_client_class.return_value = 
mock_k8s_instance - - result = self.runner.invoke(dev_space_port_forward, [ - '--name', 'test-space', - '--namespace', 'test-ns', - '--port', '8080' - ]) - - assert result.exit_code == 0 - mock_k8s_instance.list_pods_with_labels.assert_called_once_with( - namespace='test-ns', - label_selector='sagemaker.aws.com/space-name=test-space' - ) - mock_k8s_instance.port_forward_dev_space.assert_called_once_with( - namespace='test-ns', - pod_name='test-pod', - local_port='8080' - ) - - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_port_forward_no_pods(self, mock_k8s_client_class): - """Test dev space port forward with no pods""" - mock_pods = Mock() - mock_pods.items = [] - - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.return_value = mock_pods - mock_k8s_client_class.return_value = mock_k8s_instance - - result = self.runner.invoke(dev_space_port_forward, [ - '--name', 'test-space', - '--namespace', 'test-ns', - '--port', '8080' - ]) - - assert result.exit_code == 0 - assert "No pods found for dev space 'test-space'" in result.output - - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_port_forward_error(self, mock_k8s_client_class): - """Test dev space port forward error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.side_effect = Exception("Port forward failed") - mock_k8s_client_class.return_value = mock_k8s_instance - - result = self.runner.invoke(dev_space_port_forward, [ - '--name', 'test-space', - '--namespace', 'test-ns', - '--port', '8080' - ]) - - assert result.exit_code == 0 - assert "Error forwarding port for dev space 'test-space': Port forward failed" in result.output - def test_missing_required_arguments(self): """Test commands with missing required arguments""" # Test create without name @@ -564,11 +499,3 @@ def test_missing_required_arguments(self): result = self.runner.invoke(dev_space_get_logs, ['--namespace', 
'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output - - # Test port forward without port - result = self.runner.invoke(dev_space_port_forward, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - assert result.exit_code == 2 - assert "Missing option '--port'" in result.output diff --git a/test/unit_tests/clients/test_kubernetes_client.py b/test/unit_tests/clients/test_kubernetes_client.py index 5eb302fa..765e487f 100644 --- a/test/unit_tests/clients/test_kubernetes_client.py +++ b/test/unit_tests/clients/test_kubernetes_client.py @@ -697,3 +697,105 @@ def test_check_if_namespace_exists_false( test_client = KubernetesClient() result = test_client.check_if_namespace_exists("abcdef") self.assertFalse(result) + + @patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object") + def test_create_dev_space(self, mock_create_namespaced_custom_object): + """Test creating a dev space""" + test_client = KubernetesClient() + dev_space_spec = {"spec": {"image": "test-image"}} + + test_client.create_dev_space("test-namespace", dev_space_spec) + + mock_create_namespaced_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + namespace="test-namespace", + plural="spaces", + body=dev_space_spec + ) + + @patch("kubernetes.client.CustomObjectsApi.list_namespaced_custom_object") + def test_list_dev_spaces_with_namespace(self, mock_list_namespaced_custom_object): + """Test listing dev spaces in a specific namespace""" + test_client = KubernetesClient() + mock_list_namespaced_custom_object.return_value = {"items": []} + + result = test_client.list_dev_spaces("test-namespace") + + mock_list_namespaced_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + namespace="test-namespace", + plural="spaces" + ) + self.assertEqual(result, {"items": []}) + + @patch("kubernetes.client.CustomObjectsApi.list_cluster_custom_object") + def 
test_list_dev_spaces_without_namespace(self, mock_list_cluster_custom_object): + """Test listing dev spaces across all namespaces""" + test_client = KubernetesClient() + mock_list_cluster_custom_object.return_value = {"items": []} + + result = test_client.list_dev_spaces(None) + + mock_list_cluster_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + plural="spaces" + ) + self.assertEqual(result, {"items": []}) + + @patch("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object") + def test_get_dev_space(self, mock_get_namespaced_custom_object): + """Test getting a specific dev space""" + test_client = KubernetesClient() + mock_dev_space = {"metadata": {"name": "test-space"}} + mock_get_namespaced_custom_object.return_value = mock_dev_space + + result = test_client.get_dev_space("test-namespace", "test-space") + + mock_get_namespaced_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + namespace="test-namespace", + plural="spaces", + name="test-space" + ) + self.assertEqual(result, mock_dev_space) + + @patch("kubernetes.client.CustomObjectsApi.delete_namespaced_custom_object") + def test_delete_dev_space(self, mock_delete_namespaced_custom_object): + """Test deleting a dev space""" + test_client = KubernetesClient() + mock_delete_namespaced_custom_object.return_value = {} + + result = test_client.delete_dev_space("test-namespace", "test-space") + + mock_delete_namespaced_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + namespace="test-namespace", + plural="spaces", + name="test-space" + ) + self.assertEqual(result, {}) + + @patch("kubernetes.client.CustomObjectsApi.patch_namespaced_custom_object") + def test_patch_dev_space(self, mock_patch_namespaced_custom_object): + """Test patching a dev space""" + test_client = KubernetesClient() + patch_body = {"spec": {"desiredStatus": "Running"}} + mock_patch_namespaced_custom_object.return_value 
= {} + + result = test_client.patch_dev_space("test-namespace", "test-space", patch_body) + + mock_patch_namespaced_custom_object.assert_called_once_with( + group="sagemaker.aws.com", + version="v1alpha1", + namespace="test-namespace", + plural="spaces", + name="test-space", + body=patch_body + ) + self.assertEqual(result, {}) + From 1568d31f1b70873e04311674c8b14c91e32dd8f3 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Tue, 28 Oct 2025 14:38:43 -0700 Subject: [PATCH 10/31] Rename dev space to space (#272) --- hyperpod-dev-space-template/update_schema.py | 8 - .../hyperpod_space_template}/__init__.py | 0 .../hyperpod_space_template}/registry.py | 4 +- .../hyperpod_space_template}/v1_0/__init__.py | 0 .../hyperpod_space_template}/v1_0/model.py | 16 +- .../hyperpod_space_template}/v1_0/schema.json | 6 +- .../pyproject.toml | 8 +- hyperpod-space-template/update_schema.py | 8 + setup.py | 2 +- .../hyperpod/cli/clients/kubernetes_client.py | 74 ++--- .../cli/commands/{dev_space.py => space.py} | 108 +++---- ...constants.py => space_access_constants.py} | 6 +- ..._space_constants.py => space_constants.py} | 10 +- src/sagemaker/hyperpod/cli/hyp_cli.py | 50 +-- .../{dev_space_utils.py => space_utils.py} | 6 +- .../cli/{test_dev_space.py => test_space.py} | 286 +++++++++--------- ...dev_space_utils.py => test_space_utils.py} | 40 +-- .../clients/test_kubernetes_client.py | 46 +-- 18 files changed, 339 insertions(+), 339 deletions(-) delete mode 100644 hyperpod-dev-space-template/update_schema.py rename {hyperpod-dev-space-template/hyperpod_dev_space_template => hyperpod-space-template/hyperpod_space_template}/__init__.py (100%) rename {hyperpod-dev-space-template/hyperpod_dev_space_template => hyperpod-space-template/hyperpod_space_template}/registry.py (91%) rename {hyperpod-dev-space-template/hyperpod_dev_space_template => hyperpod-space-template/hyperpod_space_template}/v1_0/__init__.py (100%) rename {hyperpod-dev-space-template/hyperpod_dev_space_template => 
hyperpod-space-template/hyperpod_space_template}/v1_0/model.py (94%) rename {hyperpod-dev-space-template/hyperpod_dev_space_template => hyperpod-space-template/hyperpod_space_template}/v1_0/schema.json (96%) rename {hyperpod-dev-space-template => hyperpod-space-template}/pyproject.toml (70%) create mode 100644 hyperpod-space-template/update_schema.py rename src/sagemaker/hyperpod/cli/commands/{dev_space.py => space.py} (60%) rename src/sagemaker/hyperpod/cli/constants/{dev_space_access_constants.py => space_access_constants.py} (79%) rename src/sagemaker/hyperpod/cli/constants/{dev_space_constants.py => space_constants.py} (75%) rename src/sagemaker/hyperpod/cli/{dev_space_utils.py => space_utils.py} (96%) rename test/unit_tests/cli/{test_dev_space.py => test_space.py} (53%) rename test/unit_tests/cli/{test_dev_space_utils.py => test_space_utils.py} (90%) diff --git a/hyperpod-dev-space-template/update_schema.py b/hyperpod-dev-space-template/update_schema.py deleted file mode 100644 index 01c0b87d..00000000 --- a/hyperpod-dev-space-template/update_schema.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -import json -from hyperpod_dev_space_template.v1_0.model import DevSpaceConfig - -schema = DevSpaceConfig.model_json_schema() -with open('hyperpod_dev_space_template/v1_0/schema.json', 'w') as f: - json.dump(schema, f, indent=2) -print('✅ Schema updated!') diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py b/hyperpod-space-template/hyperpod_space_template/__init__.py similarity index 100% rename from hyperpod-dev-space-template/hyperpod_dev_space_template/__init__.py rename to hyperpod-space-template/hyperpod_space_template/__init__.py diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py b/hyperpod-space-template/hyperpod_space_template/registry.py similarity index 91% rename from hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py rename to 
hyperpod-space-template/hyperpod_space_template/registry.py index bdf80082..9d120531 100644 --- a/hyperpod-dev-space-template/hyperpod_dev_space_template/registry.py +++ b/hyperpod-space-template/hyperpod_space_template/registry.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from .v1_0.model import DevSpaceConfig +from .v1_0.model import SpaceConfig from typing import Dict, Type from pydantic import BaseModel # Direct version-to-model mapping SCHEMA_REGISTRY: Dict[str, Type[BaseModel]] = { - "1.0": DevSpaceConfig, + "1.0": SpaceConfig, } diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py b/hyperpod-space-template/hyperpod_space_template/v1_0/__init__.py similarity index 100% rename from hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/__init__.py rename to hyperpod-space-template/hyperpod_space_template/v1_0/__init__.py diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py similarity index 94% rename from hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py rename to hyperpod-space-template/hyperpod_space_template/v1_0/model.py index e30c4884..b8b06758 100644 --- a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/model.py +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py @@ -45,7 +45,7 @@ class ResourcesConfig(BaseModel): nvidia_gpu: Optional[str] = Field(default=None, alias="nvidia.com/gpu", description="GPU limit") -class DevSpaceConfig(BaseModel): +class SpaceConfig(BaseModel): model_config = ConfigDict(extra="forbid") name: str = Field( @@ -56,7 +56,7 @@ class DevSpaceConfig(BaseModel): ) image: Optional[str] = Field( default="public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu", - 
description="Container image for the dev space", + description="Container image for the space", min_length=1 ) namespace: str = Field( @@ -67,7 +67,7 @@ class DevSpaceConfig(BaseModel): desired_status: Optional[Literal['Running', 'Stopped']] = Field( default="Running", alias="desired_status", - description="Desired status of the dev space" + description="Desired status of the space" ) service_account_name: Optional[str] = Field( default="default", @@ -142,9 +142,9 @@ class DevSpaceConfig(BaseModel): def to_domain(self) -> Dict: """ - Convert flat config to domain model for dev space creation + Convert flat config to domain model for space creation """ - # Create the dev space spec + # Create the space spec spec = { "image": self.image } @@ -180,8 +180,8 @@ def to_domain(self) -> Dict: # if labels: # metadata["labels"] = labels - # Create the complete dev space configuration - dev_space_config = { + # Create the complete space configuration + space_config = { "apiVersion": "sagemaker.aws.com/v1alpha1", "kind": "Space", "metadata": metadata, @@ -191,5 +191,5 @@ def to_domain(self) -> Dict: return { "name": self.name, "namespace": self.namespace, - "dev_space_spec": dev_space_config + "space_spec": space_config } diff --git a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json similarity index 96% rename from hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json rename to hyperpod-space-template/hyperpod_space_template/v1_0/schema.json index c235a896..82693bb8 100644 --- a/hyperpod-dev-space-template/hyperpod_dev_space_template/v1_0/schema.json +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json @@ -83,7 +83,7 @@ } ], "default": "public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu", - "description": "Container image for the dev space", + "description": "Container image for the space", "title": "Image" }, "namespace": { @@ -107,7 +107,7 @@ 
} ], "default": "Running", - "description": "Desired status of the dev space", + "description": "Desired status of the space", "title": "Desired Status" }, "service_account_name": { @@ -196,6 +196,6 @@ "required": [ "name" ], - "title": "DevSpaceConfig", + "title": "SpaceConfig", "type": "object" } \ No newline at end of file diff --git a/hyperpod-dev-space-template/pyproject.toml b/hyperpod-space-template/pyproject.toml similarity index 70% rename from hyperpod-dev-space-template/pyproject.toml rename to hyperpod-space-template/pyproject.toml index 817ce58c..adaab3a8 100644 --- a/hyperpod-dev-space-template/pyproject.toml +++ b/hyperpod-space-template/pyproject.toml @@ -3,9 +3,9 @@ requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" [project] -name = "hyperpod-dev-space-template" +name = "hyperpod-space-template" version = "1.0.0" -description = "Template for HyperPod Dev Space configuration" +description = "Template for HyperPod Space configuration" authors = [ {name = "Amazon Web Services"}, ] @@ -20,7 +20,7 @@ Homepage = "https://github.com/aws/sagemaker-hyperpod-cli" [tool.setuptools.packages.find] where = ["."] -include = ["hyperpod_dev_space_template*"] +include = ["hyperpod_space_template*"] [tool.setuptools.package-data] -"hyperpod_dev_space_template.v1_0" = ["schema.json"] +"hyperpod_space_template.v1_0" = ["schema.json"] diff --git a/hyperpod-space-template/update_schema.py b/hyperpod-space-template/update_schema.py new file mode 100644 index 00000000..85a789db --- /dev/null +++ b/hyperpod-space-template/update_schema.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +import json +from hyperpod_space_template.v1_0.model import SpaceConfig + +schema = SpaceConfig.model_json_schema() +with open('hyperpod_space_template/v1_0/schema.json', 'w') as f: + json.dump(schema, f, indent=2) +print('✅ Schema updated!') diff --git a/setup.py b/setup.py index 59ab053c..14014833 100644 --- a/setup.py +++ b/setup.py @@ 
-92,7 +92,7 @@ "hyperpod-jumpstart-inference-template>=1.0.0, <2.0.0", "hyperpod-cluster-stack-template>=1.0.0, <2.0.0" # TODO: need to uncomment before pushing to master - # "hyperpod_dev_space_template>=1.0.0, <2.0.0" + # "hyperpod_space_template>=1.0.0, <2.0.0" ], entry_points={ "console_scripts": [ diff --git a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py index d7855029..55576857 100644 --- a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py +++ b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py @@ -40,21 +40,21 @@ PYTORCH_CUSTOM_OBJECT_PLURAL, PYTORCH_CUSTOM_OBJECT_VERSION, ) -from sagemaker.hyperpod.cli.constants.dev_space_constants import ( - DEV_SPACE_GROUP, - DEV_SPACE_VERSION, - DEV_SPACE_PLURAL, - DEFAULT_DEV_SPACE_PORT, +from sagemaker.hyperpod.cli.constants.space_constants import ( + SPACE_GROUP, + SPACE_VERSION, + SPACE_PLURAL, + DEFAULT_SPACE_PORT, ) from sagemaker.hyperpod.cli.constants.space_admin_config_constants import ( SPACE_ADMIN_CONFIG_GROUP, SPACE_ADMIN_CONFIG_VERSION, SPACE_ADMIN_CONFIG_PLURAL, ) -from sagemaker.hyperpod.cli.constants.dev_space_access_constants import ( - DEV_SPACE_ACCESS_GROUP, - DEV_SPACE_ACCESS_VERSION, - DEV_SPACE_ACCESS_PLURAL, +from sagemaker.hyperpod.cli.constants.space_access_constants import ( + SPACE_ACCESS_GROUP, + SPACE_ACCESS_VERSION, + SPACE_ACCESS_PLURAL, ) from sagemaker.hyperpod.cli.utils import setup_logger @@ -375,54 +375,54 @@ def get_cluster_queue(self, cluster_queue_name: str): name=cluster_queue_name ) - def create_dev_space(self, namespace: str, dev_space_spec: dict): + def create_space(self, namespace: str, space_spec: dict): return client.CustomObjectsApi().create_namespaced_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, + group=SPACE_GROUP, + version=SPACE_VERSION, namespace=namespace, - plural=DEV_SPACE_PLURAL, - body=dev_space_spec + plural=SPACE_PLURAL, + body=space_spec ) - def 
list_dev_spaces(self, namespace: str): + def list_spaces(self, namespace: str): if namespace: return client.CustomObjectsApi().list_namespaced_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, + group=SPACE_GROUP, + version=SPACE_VERSION, namespace=namespace, - plural=DEV_SPACE_PLURAL + plural=SPACE_PLURAL ) else: return client.CustomObjectsApi().list_cluster_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, - plural=DEV_SPACE_PLURAL + group=SPACE_GROUP, + version=SPACE_VERSION, + plural=SPACE_PLURAL ) - def get_dev_space(self, namespace: str, name: str): + def get_space(self, namespace: str, name: str): return client.CustomObjectsApi().get_namespaced_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, + group=SPACE_GROUP, + version=SPACE_VERSION, namespace=namespace, - plural=DEV_SPACE_PLURAL, + plural=SPACE_PLURAL, name=name ) - def delete_dev_space(self, namespace: str, name: str): + def delete_space(self, namespace: str, name: str): return client.CustomObjectsApi().delete_namespaced_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, + group=SPACE_GROUP, + version=SPACE_VERSION, namespace=namespace, - plural=DEV_SPACE_PLURAL, + plural=SPACE_PLURAL, name=name ) - def patch_dev_space(self, namespace: str, name: str, body: dict): + def patch_space(self, namespace: str, name: str, body: dict): return client.CustomObjectsApi().patch_namespaced_custom_object( - group=DEV_SPACE_GROUP, - version=DEV_SPACE_VERSION, + group=SPACE_GROUP, + version=SPACE_VERSION, namespace=namespace, - plural=DEV_SPACE_PLURAL, + plural=SPACE_PLURAL, name=name, body=body ) @@ -482,12 +482,12 @@ def patch_space_admin_config(self, namespace: str, name: str, body: dict): body=body ) - def create_dev_space_access(self, namespace: str, config_spec: dict): + def create_space_access(self, namespace: str, config_spec: dict): return client.CustomObjectsApi().create_namespaced_custom_object( - group=DEV_SPACE_ACCESS_GROUP, - 
version=DEV_SPACE_ACCESS_VERSION, + group=SPACE_ACCESS_GROUP, + version=SPACE_ACCESS_VERSION, namespace=namespace, - plural=DEV_SPACE_ACCESS_PLURAL, + plural=SPACE_ACCESS_PLURAL, body=config_spec ) diff --git a/src/sagemaker/hyperpod/cli/commands/dev_space.py b/src/sagemaker/hyperpod/cli/commands/space.py similarity index 60% rename from src/sagemaker/hyperpod/cli/commands/dev_space.py rename to src/sagemaker/hyperpod/cli/commands/space.py index 3164fca4..f8e0473b 100644 --- a/src/sagemaker/hyperpod/cli/commands/dev_space.py +++ b/src/sagemaker/hyperpod/cli/commands/space.py @@ -2,44 +2,44 @@ import json from tabulate import tabulate from sagemaker.hyperpod.cli.clients.kubernetes_client import KubernetesClient -from sagemaker.hyperpod.cli.dev_space_utils import generate_click_command -from hyperpod_dev_space_template.registry import SCHEMA_REGISTRY +from sagemaker.hyperpod.cli.space_utils import generate_click_command +from hyperpod_space_template.registry import SCHEMA_REGISTRY from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( _hyperpod_telemetry_emitter, ) from sagemaker.hyperpod.common.telemetry.constants import Feature -@click.command("hyp-dev-space") +@click.command("hyp-space") @generate_click_command( - schema_pkg="hyperpod_dev_space_template", + schema_pkg="hyperpod_space_template", registry=SCHEMA_REGISTRY, ) -def dev_space_create(version, config): - """Create a dev-space resource.""" +def space_create(version, config): + """Create a space resource.""" try: name = config.get("name") namespace = config.get("namespace") - dev_space_spec = config.get("dev_space_spec") + space_spec = config.get("space_spec") k8s_client = KubernetesClient() - k8s_client.create_dev_space(namespace, dev_space_spec) + k8s_client.create_space(namespace, space_spec) click.echo(f"Dev space '{name}' created successfully in namespace '{namespace}'") except Exception as e: - click.echo(f"Error creating dev space: {e}", err=True) + click.echo(f"Error creating space: 
{e}", err=True) -@click.command("hyp-dev-space") +@click.command("hyp-space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") -def dev_space_list(namespace, output): - """List dev-space resources.""" +def space_list(namespace, output): + """List space resources.""" k8s_client = KubernetesClient() try: - resources = k8s_client.list_dev_spaces(namespace) + resources = k8s_client.list_spaces(namespace) if output == "json": click.echo(json.dumps(resources, indent=2)) @@ -55,21 +55,21 @@ def dev_space_list(namespace, output): ]) click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "STATUS"])) else: - click.echo("No dev spaces found") + click.echo("No spaces found") except Exception as e: - click.echo(f"Error listing dev spaces: {e}", err=True) + click.echo(f"Error listing spaces: {e}", err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") -def dev_space_describe(name, namespace, output): - """Describe a dev-space resource.""" +def space_describe(name, namespace, output): + """Describe a space resource.""" k8s_client = KubernetesClient() try: - resource = k8s_client.get_dev_space(namespace, name) + resource = k8s_client.get_space(namespace, name) resource["metadata"].pop('managedFields', None) if output == "json": @@ -78,61 +78,61 @@ def dev_space_describe(name, namespace, output): import yaml click.echo(yaml.dump(resource, default_flow_style=False)) except Exception as e: - click.echo(f"Error describing dev space '{name}': {e}", err=True) + click.echo(f"Error describing space '{name}': {e}", 
err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -def dev_space_delete(name, namespace): - """Delete a dev-space resource.""" +def space_delete(name, namespace): + """Delete a space resource.""" k8s_client = KubernetesClient() try: - k8s_client.delete_dev_space(namespace, name) + k8s_client.delete_space(namespace, name) click.echo(f"Dev space '{name}' deleted successfully") except Exception as e: - click.echo(f"Error deleting dev space '{name}': {e}", err=True) + click.echo(f"Error deleting space '{name}': {e}", err=True) -@click.command("hyp-dev-space") +@click.command("hyp-space") @generate_click_command( - schema_pkg="hyperpod_dev_space_template", + schema_pkg="hyperpod_space_template", registry=SCHEMA_REGISTRY, is_update=True, ) -def dev_space_update(version, config): - """Update a dev-space resource.""" +def space_update(version, config): + """Update a space resource.""" k8s_client = KubernetesClient() try: name = config["name"] namespace = config["namespace"] - dev_space_spec = config.get("dev_space_spec", {}) + space_spec = config.get("space_spec", {}) - k8s_client.patch_dev_space( + k8s_client.patch_space( namespace=namespace, name=name, - body=dev_space_spec + body=space_spec ) click.echo(f"Dev space '{name}' updated successfully") except Exception as e: - click.echo(f"Error updating dev space '{name}': {e}", err=True) + click.echo(f"Error updating space '{name}': {e}", err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -def dev_space_start(name, 
namespace): - """Start a dev-space resource.""" +def space_start(name, namespace): + """Start a space resource.""" k8s_client = KubernetesClient() try: # Patch the resource to set desired status to "Running" patch_body = {"spec": {"desiredStatus": "Running"}} - k8s_client.patch_dev_space( + k8s_client.patch_space( namespace=namespace, name=name, body=patch_body @@ -140,20 +140,20 @@ def dev_space_start(name, namespace): click.echo(f"Dev space '{name}' start requested") except Exception as e: - click.echo(f"Error starting dev space '{name}': {e}", err=True) + click.echo(f"Error starting space '{name}': {e}", err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -def dev_space_stop(name, namespace): - """Stop a dev-space resource.""" +def space_stop(name, namespace): + """Stop a space resource.""" k8s_client = KubernetesClient() try: # Patch the resource to set desired status to "Stopped" patch_body = {"spec": {"desiredStatus": "Stopped"}} - k8s_client.patch_dev_space( + k8s_client.patch_space( namespace=namespace, name=name, body=patch_body @@ -161,25 +161,25 @@ def dev_space_stop(name, namespace): click.echo(f"Dev space '{name}' stop requested") except Exception as e: - click.echo(f"Error stopping dev space '{name}': {e}", err=True) + click.echo(f"Error stopping space '{name}': {e}", err=True) -@click.command("hyp-dev-space") -@click.option("--name", required=True, help="Name of the dev space") +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -def dev_space_get_logs(name, namespace): - """Get logs for a dev-space resource.""" +def space_get_logs(name, namespace): + """Get 
logs for a space resource.""" k8s_client = KubernetesClient() try: - # Get pods associated with the dev space + # Get pods associated with the space pods = k8s_client.list_pods_with_labels( namespace=namespace, label_selector=f"sagemaker.aws.com/space-name={name}" ) if not pods.items: - click.echo(f"No pods found for dev space '{name}'") + click.echo(f"No pods found for space '{name}'") return # Get logs from the first pod @@ -191,7 +191,7 @@ def dev_space_get_logs(name, namespace): click.echo(logs) except Exception as e: - click.echo(f"Error getting logs for dev space '{name}': {e}", err=True) + click.echo(f"Error getting logs for space '{name}': {e}", err=True) diff --git a/src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py b/src/sagemaker/hyperpod/cli/constants/space_access_constants.py similarity index 79% rename from src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py rename to src/sagemaker/hyperpod/cli/constants/space_access_constants.py index 6d41c6a0..55fc9522 100644 --- a/src/sagemaker/hyperpod/cli/constants/dev_space_access_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_access_constants.py @@ -11,6 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-DEV_SPACE_ACCESS_GROUP = "access.devspaces.sagemaker.aws.com" -DEV_SPACE_ACCESS_VERSION = "v1alpha1" -DEV_SPACE_ACCESS_PLURAL = "devspaceaccess" +SPACE_ACCESS_GROUP = "access.devspaces.sagemaker.aws.com" +SPACE_ACCESS_VERSION = "v1alpha1" +SPACE_ACCESS_PLURAL = "devspaceaccess" diff --git a/src/sagemaker/hyperpod/cli/constants/dev_space_constants.py b/src/sagemaker/hyperpod/cli/constants/space_constants.py similarity index 75% rename from src/sagemaker/hyperpod/cli/constants/dev_space_constants.py rename to src/sagemaker/hyperpod/cli/constants/space_constants.py index 3d6cab95..006a9235 100644 --- a/src/sagemaker/hyperpod/cli/constants/dev_space_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_constants.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -DEV_SPACE_GROUP = "sagemaker.aws.com" -DEV_SPACE_VERSION = "v1alpha1" -DEV_SPACE_PLURAL = "spaces" -DEFAULT_DEV_SPACE_PORT = "8888" -# Immutable fields that cannot be updated after dev space creation +SPACE_GROUP = "sagemaker.aws.com" +SPACE_VERSION = "v1alpha1" +SPACE_PLURAL = "spaces" +DEFAULT_SPACE_PORT = "8888" +# Immutable fields that cannot be updated after space creation IMMUTABLE_FIELDS = { "storage_class_name", } \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index 96a1bbcc..3904cc50 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -38,15 +38,15 @@ js_get_operator_logs, custom_get_operator_logs, ) -from sagemaker.hyperpod.cli.commands.dev_space import ( - dev_space_create, - dev_space_list, - dev_space_describe, - dev_space_delete, - dev_space_update, - dev_space_start, - dev_space_stop, - dev_space_get_logs, +from sagemaker.hyperpod.cli.commands.space import ( + space_create, + space_list, + 
space_describe, + space_delete, + space_update, + space_start, + space_stop, + space_get_logs, ) from sagemaker.hyperpod.cli.commands.init import ( @@ -107,7 +107,7 @@ def parse_args(self, ctx, args): @cli.group(cls=CLICommand, default_cmd='_default_create') def create(): """ - Create endpoints, pytorch jobs, cluster stacks, dev space, dev space access or space admin config. + Create endpoints, pytorch jobs, cluster stacks, space, space access or space admin config. If only used as 'hyp create' without [OPTIONS] COMMAND [ARGS] during init experience, then it will validate configuration and render template files for deployment. @@ -123,35 +123,35 @@ def create(): @cli.group(cls=CLICommand) def list(): - """List endpoints, pytorch jobs, cluster stacks or dev spaces.""" + """List endpoints, pytorch jobs, cluster stacks or spaces.""" pass @cli.group(cls=CLICommand) def describe(): - """Describe endpoints, pytorch jobs or cluster stacks, dev spaces or space admin configs.""" + """Describe endpoints, pytorch jobs or cluster stacks, spaces or space admin configs.""" pass @cli.group(cls=CLICommand) def update(): - """Update an existing HyperPod cluster configuration, dev space, or space admin config.""" + """Update an existing HyperPod cluster configuration, space, or space admin config.""" pass @cli.group(cls=CLICommand) def delete(): - """Delete endpoints, pytorch jobs, dev space, dev space access or space admin config.""" + """Delete endpoints, pytorch jobs, space, space access or space admin config.""" pass @cli.group(cls=CLICommand) def start(): - """Start dev space resources.""" + """Start space resources.""" pass @cli.group(cls=CLICommand) def stop(): - """Stop dev space resources.""" + """Stop space resources.""" pass @@ -166,7 +166,7 @@ def list_pods(): @cli.group(cls=CLICommand) def get_logs(): - """Get pod logs for endpoints, pytorch jobs or dev spaces.""" + """Get pod logs for endpoints, pytorch jobs or spaces.""" pass @@ -199,13 +199,13 @@ def exec(): 
_default_create.hidden = True create.add_command(_default_create) -create.add_command(dev_space_create) +create.add_command(space_create) list.add_command(list_jobs) list.add_command(js_list) list.add_command(custom_list) list.add_command(list_cluster_stacks) -list.add_command(dev_space_list) +list.add_command(space_list) describe.add_command(pytorch_describe) describe.add_command(js_describe) @@ -213,20 +213,20 @@ def exec(): describe.add_command(describe_cluster_stack) describe.add_command(describe_cluster) -describe.add_command(dev_space_describe) +describe.add_command(space_describe) update.add_command(update_cluster) -update.add_command(dev_space_update) +update.add_command(space_update) delete.add_command(pytorch_delete) delete.add_command(js_delete) delete.add_command(custom_delete) delete.add_command(delete_cluster_stack) -delete.add_command(dev_space_delete) +delete.add_command(space_delete) -start.add_command(dev_space_start) +start.add_command(space_start) -stop.add_command(dev_space_stop) +stop.add_command(space_stop) list_pods.add_command(pytorch_list_pods) list_pods.add_command(js_list_pods) @@ -235,7 +235,7 @@ def exec(): get_logs.add_command(pytorch_get_logs) get_logs.add_command(js_get_logs) get_logs.add_command(custom_get_logs) -get_logs.add_command(dev_space_get_logs) +get_logs.add_command(space_get_logs) diff --git a/src/sagemaker/hyperpod/cli/dev_space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py similarity index 96% rename from src/sagemaker/hyperpod/cli/dev_space_utils.py rename to src/sagemaker/hyperpod/cli/space_utils.py index f9f94d03..b32071ab 100644 --- a/src/sagemaker/hyperpod/cli/dev_space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -3,7 +3,7 @@ import click from typing import Callable, Optional, Mapping, Type, Dict, Any from pydantic import ValidationError -from sagemaker.hyperpod.cli.constants.dev_space_constants import IMMUTABLE_FIELDS +from sagemaker.hyperpod.cli.constants.space_constants import 
IMMUTABLE_FIELDS def load_schema_for_version( @@ -26,12 +26,12 @@ def load_schema_for_version( def generate_click_command( *, version_key: Optional[str] = None, - schema_pkg: str = "hyperpod_dev_space_template", + schema_pkg: str = "hyperpod_space_template", registry: Mapping[str, Type] = None, is_update: bool = False, ) -> Callable: """ - Decorator factory for dev space commands. + Decorator factory for space commands. """ if registry is None: raise ValueError("You must pass a registry mapping version→Model") diff --git a/test/unit_tests/cli/test_dev_space.py b/test/unit_tests/cli/test_space.py similarity index 53% rename from test/unit_tests/cli/test_dev_space.py rename to test/unit_tests/cli/test_space.py index cd7ddae4..111f4770 100644 --- a/test/unit_tests/cli/test_dev_space.py +++ b/test/unit_tests/cli/test_space.py @@ -3,29 +3,29 @@ from click.testing import CliRunner from unittest.mock import Mock, patch, MagicMock -from sagemaker.hyperpod.cli.commands.dev_space import ( - dev_space_create, - dev_space_list, - dev_space_describe, - dev_space_delete, - dev_space_update, - dev_space_start, - dev_space_stop, - dev_space_get_logs, +from sagemaker.hyperpod.cli.commands.space import ( + space_create, + space_list, + space_describe, + space_delete, + space_update, + space_start, + space_stop, + space_get_logs, ) -class TestDevSpaceCommands: - """Test cases for dev space commands""" +class TestSpaceCommands: + """Test cases for space commands""" def setup_method(self): self.runner = CliRunner() self.mock_k8s_client = Mock() - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_create_success(self, mock_k8s_client_class, mock_load_schema): - """Test successful dev space creation""" + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_create_success(self, 
mock_k8s_client_class, mock_load_schema): + """Test successful space creation""" # Mock schema loading mock_load_schema.return_value = { "properties": { @@ -41,15 +41,15 @@ def test_dev_space_create_success(self, mock_k8s_client_class, mock_load_schema) mock_model.return_value.to_domain.return_value = { "name": "test-space", "namespace": "test-ns", - "dev_space_spec": {"spec": {"image": "test-image"}} + "space_spec": {"spec": {"image": "test-image"}} } # Mock KubernetesClient mock_k8s_instance = Mock() mock_k8s_client_class.return_value = mock_k8s_instance - with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - result = self.runner.invoke(dev_space_create, [ + with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + result = self.runner.invoke(space_create, [ '--version', '1.0', '--name', 'test-space', '--namespace', 'test-ns' @@ -57,25 +57,25 @@ def test_dev_space_create_success(self, mock_k8s_client_class, mock_load_schema) assert result.exit_code == 0 assert "Dev space 'test-space' created successfully" in result.output - mock_k8s_instance.create_dev_space.assert_called_once() + mock_k8s_instance.create_space.assert_called_once() - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') - def test_dev_space_create_missing_required_args(self, mock_load_schema): - """Test dev space creation with missing required arguments""" + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_space_create_missing_required_args(self, mock_load_schema): + """Test space creation with missing required arguments""" mock_load_schema.return_value = { "properties": {"name": {"type": "string"}}, "required": ["name"] } - result = self.runner.invoke(dev_space_create, ['--version', '1.0']) + result = self.runner.invoke(space_create, ['--version', '1.0']) assert result.exit_code != 0 assert 'Missing option' in result.output - 
@patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_create_k8s_error(self, mock_k8s_client_class): - """Test dev space creation error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_create_k8s_error(self, mock_k8s_client_class): + """Test space creation error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.create_dev_space.side_effect = Exception("Creation failed") + mock_k8s_instance.create_space.side_effect = Exception("Creation failed") mock_k8s_client_class.return_value = mock_k8s_instance mock_model = Mock() @@ -83,11 +83,11 @@ def test_dev_space_create_k8s_error(self, mock_k8s_client_class): mock_model.return_value.to_domain.return_value = { "name": "test-space", "namespace": "test-ns", - "dev_space_spec": {} + "space_spec": {} } - with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - with patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') as mock_load_schema: + with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + with patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') as mock_load_schema: mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, @@ -95,20 +95,20 @@ def test_dev_space_create_k8s_error(self, mock_k8s_client_class): }, "required": ["name", "namespace"] } - result = self.runner.invoke(dev_space_create, [ + result = self.runner.invoke(space_create, [ '--version', '1.0', '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error creating dev space: Creation failed" in result.output + assert "Error creating space: Creation failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_list_table_output(self, mock_k8s_client_class): - """Test dev space list with table output""" + 
@patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_list_table_output(self, mock_k8s_client_class): + """Test space list with table output""" mock_k8s_instance = Mock() - mock_k8s_instance.list_dev_spaces.return_value = { + mock_k8s_instance.list_spaces.return_value = { "items": [ { "metadata": {"name": "space1", "namespace": "ns1"}, @@ -122,7 +122,7 @@ def test_dev_space_list_table_output(self, mock_k8s_client_class): } mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_list, [ + result = self.runner.invoke(space_list, [ '--namespace', 'test-ns', '--output', 'table' ]) @@ -130,21 +130,21 @@ def test_dev_space_list_table_output(self, mock_k8s_client_class): assert result.exit_code == 0 assert "space1" in result.output assert "space2" in result.output - mock_k8s_instance.list_dev_spaces.assert_called_once_with('test-ns') + mock_k8s_instance.list_spaces.assert_called_once_with('test-ns') - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_list_json_output(self, mock_k8s_client_class): - """Test dev space list with JSON output""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_list_json_output(self, mock_k8s_client_class): + """Test space list with JSON output""" mock_resources = { "items": [ {"metadata": {"name": "space1", "namespace": "ns1"}} ] } mock_k8s_instance = Mock() - mock_k8s_instance.list_dev_spaces.return_value = mock_resources + mock_k8s_instance.list_spaces.return_value = mock_resources mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_list, [ + result = self.runner.invoke(space_list, [ '--namespace', 'test-ns', '--output', 'json' ]) @@ -153,62 +153,62 @@ def test_dev_space_list_json_output(self, mock_k8s_client_class): output_json = json.loads(result.output) assert output_json == mock_resources - 
@patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_list_empty(self, mock_k8s_client_class): - """Test dev space list with no items""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_list_empty(self, mock_k8s_client_class): + """Test space list with no items""" mock_k8s_instance = Mock() - mock_k8s_instance.list_dev_spaces.return_value = {"items": []} + mock_k8s_instance.list_spaces.return_value = {"items": []} mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_list, [ + result = self.runner.invoke(space_list, [ '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "No dev spaces found" in result.output + assert "No spaces found" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_list_error(self, mock_k8s_client_class): - """Test dev space list error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_list_error(self, mock_k8s_client_class): + """Test space list error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.list_dev_spaces.side_effect = Exception("List failed") + mock_k8s_instance.list_spaces.side_effect = Exception("List failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_list, [ + result = self.runner.invoke(space_list, [ '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error listing dev spaces: List failed" in result.output + assert "Error listing spaces: List failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_describe_yaml_output(self, mock_k8s_client_class): - """Test dev space describe with YAML output""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_describe_yaml_output(self, mock_k8s_client_class): + """Test space describe with YAML 
output""" mock_resource = {"metadata": {"name": "test-space"}} mock_k8s_instance = Mock() - mock_k8s_instance.get_dev_space.return_value = mock_resource + mock_k8s_instance.get_space.return_value = mock_resource mock_k8s_client_class.return_value = mock_k8s_instance with patch('yaml.dump') as mock_yaml_dump: mock_yaml_dump.return_value = "yaml_output" - result = self.runner.invoke(dev_space_describe, [ + result = self.runner.invoke(space_describe, [ '--name', 'test-space', '--namespace', 'test-ns', ]) assert result.exit_code == 0 assert "yaml_output" in result.output - mock_k8s_instance.get_dev_space.assert_called_once_with('test-ns', 'test-space') + mock_k8s_instance.get_space.assert_called_once_with('test-ns', 'test-space') - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_describe_json_output(self, mock_k8s_client_class): - """Test dev space describe with JSON output""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_describe_json_output(self, mock_k8s_client_class): + """Test space describe with JSON output""" mock_resource = {"metadata": {"name": "test-space"}} mock_k8s_instance = Mock() - mock_k8s_instance.get_dev_space.return_value = mock_resource + mock_k8s_instance.get_space.return_value = mock_resource mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_describe, [ + result = self.runner.invoke(space_describe, [ '--name', 'test-space', '--namespace', 'test-ns', '--output', 'json' @@ -218,55 +218,55 @@ def test_dev_space_describe_json_output(self, mock_k8s_client_class): output_json = json.loads(result.output) assert output_json == mock_resource - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_describe_k8s_error(self, mock_k8s_client_class): - """Test dev space describe error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_describe_k8s_error(self, 
mock_k8s_client_class): + """Test space describe error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.get_dev_space.side_effect = Exception("Describe failed") + mock_k8s_instance.get_space.side_effect = Exception("Describe failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_describe, [ + result = self.runner.invoke(space_describe, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error describing dev space 'test-space': Describe failed" in result.output + assert "Error describing space 'test-space': Describe failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_delete_success(self, mock_k8s_client_class): - """Test successful dev space deletion""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_delete_success(self, mock_k8s_client_class): + """Test successful space deletion""" mock_k8s_instance = Mock() mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_delete, [ + result = self.runner.invoke(space_delete, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 assert "Dev space 'test-space' deleted successfully" in result.output - mock_k8s_instance.delete_dev_space.assert_called_once_with('test-ns', 'test-space') + mock_k8s_instance.delete_space.assert_called_once_with('test-ns', 'test-space') - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_delete_k8s_error(self, mock_k8s_client_class): - """Test dev space delete error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_delete_k8s_error(self, mock_k8s_client_class): + """Test space delete error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.delete_dev_space.side_effect = Exception("Delete failed") + mock_k8s_instance.delete_space.side_effect = 
Exception("Delete failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_delete, [ + result = self.runner.invoke(space_delete, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error deleting dev space 'test-space': Delete failed" in result.output + assert "Error deleting space 'test-space': Delete failed" in result.output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_update_success(self, mock_k8s_client_class, mock_load_schema): - """Test successful dev space update""" + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): + """Test successful space update""" # Mock schema loading mock_load_schema.return_value = { "properties": { @@ -282,15 +282,15 @@ def test_dev_space_update_success(self, mock_k8s_client_class, mock_load_schema) mock_model.return_value.to_domain.return_value = { "name": "test-space", "namespace": "test-ns", - "dev_space_spec": {"spec": {"image": "updated-image"}} + "space_spec": {"spec": {"image": "updated-image"}} } # Mock KubernetesClient mock_k8s_instance = Mock() mock_k8s_client_class.return_value = mock_k8s_instance - with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - result = self.runner.invoke(dev_space_update, [ + with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + result = self.runner.invoke(space_update, [ '--version', '1.0', '--name', 'test-space', '--namespace', 'test-ns' @@ -298,13 +298,13 @@ def test_dev_space_update_success(self, mock_k8s_client_class, mock_load_schema) assert result.exit_code == 0 assert "Dev space 'test-space' updated successfully" in result.output - 
mock_k8s_instance.patch_dev_space.assert_called_once() + mock_k8s_instance.patch_space.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_update_k8_error(self, mock_k8s_client_class): - """Test dev space update error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_update_k8_error(self, mock_k8s_client_class): + """Test space update error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.patch_dev_space.side_effect = Exception("Update failed") + mock_k8s_instance.patch_space.side_effect = Exception("Update failed") mock_k8s_client_class.return_value = mock_k8s_instance mock_model = Mock() @@ -312,11 +312,11 @@ def test_dev_space_update_k8_error(self, mock_k8s_client_class): mock_model.return_value.to_domain.return_value = { "name": "test-space", "namespace": "test-ns", - "dev_space_spec": {} + "space_spec": {} } - with patch('hyperpod_dev_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - with patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') as mock_load_schema: + with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): + with patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') as mock_load_schema: mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, @@ -324,86 +324,86 @@ def test_dev_space_update_k8_error(self, mock_k8s_client_class): }, "required": ["name", "namespace"] } - result = self.runner.invoke(dev_space_update, [ + result = self.runner.invoke(space_update, [ '--version', '1.0', '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error updating dev space 'test-space': Update failed" in result.output + assert "Error updating space 'test-space': Update failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_start_success(self, 
mock_k8s_client_class): - """Test successful dev space start""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_start_success(self, mock_k8s_client_class): + """Test successful space start""" mock_k8s_instance = Mock() mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_start, [ + result = self.runner.invoke(space_start, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 assert "Dev space 'test-space' start requested" in result.output - mock_k8s_instance.patch_dev_space.assert_called_once_with( + mock_k8s_instance.patch_space.assert_called_once_with( namespace='test-ns', name='test-space', body={"spec": {"desiredStatus": "Running"}} ) - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_start_k8s_error(self, mock_k8s_client_class): - """Test dev space start error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_start_k8s_error(self, mock_k8s_client_class): + """Test space start error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.patch_dev_space.side_effect = Exception("Start failed") + mock_k8s_instance.patch_space.side_effect = Exception("Start failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_start, [ + result = self.runner.invoke(space_start, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error starting dev space 'test-space': Start failed" in result.output + assert "Error starting space 'test-space': Start failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_stop_success(self, mock_k8s_client_class): - """Test successful dev space stop""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_stop_success(self, mock_k8s_client_class): + """Test successful space 
stop""" mock_k8s_instance = Mock() mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_stop, [ + result = self.runner.invoke(space_stop, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 assert "Dev space 'test-space' stop requested" in result.output - mock_k8s_instance.patch_dev_space.assert_called_once_with( + mock_k8s_instance.patch_space.assert_called_once_with( namespace='test-ns', name='test-space', body={"spec": {"desiredStatus": "Stopped"}} ) - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_stop_k8s_error(self, mock_k8s_client_class): - """Test dev space stop error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_stop_k8s_error(self, mock_k8s_client_class): + """Test space stop error handling""" mock_k8s_instance = Mock() - mock_k8s_instance.patch_dev_space.side_effect = Exception("Stop failed") + mock_k8s_instance.patch_space.side_effect = Exception("Stop failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_stop, [ + result = self.runner.invoke(space_stop, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error stopping dev space 'test-space': Stop failed" in result.output + assert "Error stopping space 'test-space': Stop failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_get_logs_success(self, mock_k8s_client_class): - """Test successful dev space get logs""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_get_logs_success(self, mock_k8s_client_class): + """Test successful space get logs""" mock_pod = Mock() mock_pod.metadata.name = "test-pod" mock_pods = Mock() @@ -414,7 +414,7 @@ def test_dev_space_get_logs_success(self, mock_k8s_client_class): mock_k8s_instance.get_logs_for_pod.return_value = 
"test logs" mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_get_logs, [ + result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', '--namespace', 'test-ns' ]) @@ -430,9 +430,9 @@ def test_dev_space_get_logs_success(self, mock_k8s_client_class): namespace='test-ns' ) - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_get_logs_no_pods(self, mock_k8s_client_class): - """Test dev space get logs with no pods""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_get_logs_no_pods(self, mock_k8s_client_class): + """Test space get logs with no pods""" mock_pods = Mock() mock_pods.items = [] @@ -440,62 +440,62 @@ def test_dev_space_get_logs_no_pods(self, mock_k8s_client_class): mock_k8s_instance.list_pods_with_labels.return_value = mock_pods mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_get_logs, [ + result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "No pods found for dev space 'test-space'" in result.output + assert "No pods found for space 'test-space'" in result.output - @patch('sagemaker.hyperpod.cli.commands.dev_space.KubernetesClient') - def test_dev_space_get_logs_k8s_error(self, mock_k8s_client_class): - """Test dev space get logs error handling""" + @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') + def test_space_get_logs_k8s_error(self, mock_k8s_client_class): + """Test space get logs error handling""" mock_k8s_instance = Mock() mock_k8s_instance.list_pods_with_labels.side_effect = Exception("List pod failed") mock_k8s_client_class.return_value = mock_k8s_instance - result = self.runner.invoke(dev_space_get_logs, [ + result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Error getting logs for dev space 
'test-space': List pod failed" in result.output + assert "Error getting logs for space 'test-space': List pod failed" in result.output def test_missing_required_arguments(self): """Test commands with missing required arguments""" # Test create without name - result = self.runner.invoke(dev_space_create, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_create, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test describe without name - result = self.runner.invoke(dev_space_describe, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_describe, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test delete without name - result = self.runner.invoke(dev_space_delete, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_delete, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test update without name - result = self.runner.invoke(dev_space_update, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_update, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test start without name - result = self.runner.invoke(dev_space_start, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_start, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test stop without name - result = self.runner.invoke(dev_space_stop, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_stop, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output # Test get logs without name - result = self.runner.invoke(dev_space_get_logs, ['--namespace', 'test-ns']) + result = self.runner.invoke(space_get_logs, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in 
result.output diff --git a/test/unit_tests/cli/test_dev_space_utils.py b/test/unit_tests/cli/test_space_utils.py similarity index 90% rename from test/unit_tests/cli/test_dev_space_utils.py rename to test/unit_tests/cli/test_space_utils.py index d9e3a203..dfe8d389 100644 --- a/test/unit_tests/cli/test_dev_space_utils.py +++ b/test/unit_tests/cli/test_space_utils.py @@ -5,11 +5,11 @@ from unittest.mock import Mock, patch from pydantic import ValidationError -from sagemaker.hyperpod.cli.dev_space_utils import load_schema_for_version, generate_click_command +from sagemaker.hyperpod.cli.space_utils import load_schema_for_version, generate_click_command class TestLoadSchemaForVersion: - @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + @patch('sagemaker.hyperpod.cli.space_utils.pkgutil.get_data') def test_success(self, mock_get_data): """Test successful schema loading""" data = {"properties": {"name": {"type": "string"}}} @@ -20,7 +20,7 @@ def test_success(self, mock_get_data): assert result == data mock_get_data.assert_called_once_with('test_package.v1_2', 'schema.json') - @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + @patch('sagemaker.hyperpod.cli.space_utils.pkgutil.get_data') def test_schema_not_found(self, mock_get_data): """Test handling of missing schema file""" mock_get_data.return_value = None @@ -30,7 +30,7 @@ def test_schema_not_found(self, mock_get_data): assert "Could not load schema.json for version 1.0" in str(exc.value) - @patch('sagemaker.hyperpod.cli.dev_space_utils.pkgutil.get_data') + @patch('sagemaker.hyperpod.cli.space_utils.pkgutil.get_data') def test_invalid_json_schema(self, mock_get_data): """Test handling of invalid JSON in schema file""" mock_get_data.return_value = b'invalid json' @@ -49,7 +49,7 @@ def test_missing_registry(self): generate_click_command(schema_pkg="test_package") assert "You must pass a registry mapping" in str(exc.value) - 
@patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_unsupported_version(self, mock_load_schema): """Test handling of unsupported version""" mock_load_schema.return_value = {'properties': {}, 'required': []} @@ -64,7 +64,7 @@ def cmd(version, domain_config): assert result.exit_code != 0 assert 'Unsupported schema version: 1.0' in result.output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_version_handling(self, mock_load_schema): """Test version handling in command generation""" schema = {'properties': {}, 'required': []} @@ -91,7 +91,7 @@ def cmd(version, domain_config): assert result.exit_code == 0 assert result.output.strip() == '2.0' - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_resources_building(self, mock_load_schema): """Test CPU and memory resource building""" schema = { @@ -117,7 +117,7 @@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo(json.dumps(domain_config.resources)) @@ -141,7 +141,7 @@ def cmd(version, domain_config): assert result.exit_code == 0 assert result.output.strip() == 'null' - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_type_conversion(self, mock_load_schema): """Test type conversion for different parameter types""" schema = { @@ -164,7 +164,7 @@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry, 
schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo(json.dumps({ 'name': domain_config.name, @@ -193,7 +193,7 @@ def cmd(version, domain_config): assert result.exit_code == 2 assert "Invalid value" in result.output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_successful_command_execution(self, mock_load_schema): """Test successful command execution with valid parameters""" schema = { @@ -214,7 +214,7 @@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo(f'success: {domain_config.name}') @@ -223,7 +223,7 @@ def cmd(version, domain_config): assert result.exit_code == 0 assert 'success: test-space' in result.output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_immutable_fields_excluded_in_update(self, mock_load_schema): """Test that immutable fields are excluded in update mode""" schema = { @@ -247,7 +247,7 @@ def to_domain(self): @click.command() @generate_click_command( registry=registry, - schema_pkg="hyperpod_dev_space_template", + schema_pkg="hyperpod_space_template", is_update=True ) def cmd(version, domain_config): @@ -262,7 +262,7 @@ def cmd(version, domain_config): assert '--name' in result.output assert '--image' in result.output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_filtered_kwargs(self, mock_load_schema): """Test that None/empty values are filtered out""" schema = { @@ -285,7 +285,7 
@@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): # Check that None values were filtered out click.echo(json.dumps(domain_config.received_kwargs)) @@ -297,7 +297,7 @@ def cmd(version, domain_config): assert output['image'] == 'default-image' assert 'namespace' not in output - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_default_version_injection(self, mock_load_schema): """Test that version flag is injected when no version_key is provided""" schema = {'properties': {}, 'required': []} @@ -310,7 +310,7 @@ def to_domain(self): return self registry = {'1.0': DummyModel, '2.0': DummyModel} @click.command() - @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo(version) @@ -325,7 +325,7 @@ def cmd(version, domain_config): assert result.exit_code == 0 assert result.output.strip() == '2.0' - @patch('sagemaker.hyperpod.cli.dev_space_utils.load_schema_for_version') + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_schema_defaults_and_required_fields(self, mock_load_schema): """Test handling of schema defaults and required fields""" schema = { @@ -347,7 +347,7 @@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry, schema_pkg="hyperpod_dev_space_template") + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo('success') diff --git a/test/unit_tests/clients/test_kubernetes_client.py 
b/test/unit_tests/clients/test_kubernetes_client.py index 765e487f..5c30a17b 100644 --- a/test/unit_tests/clients/test_kubernetes_client.py +++ b/test/unit_tests/clients/test_kubernetes_client.py @@ -699,28 +699,28 @@ def test_check_if_namespace_exists_false( self.assertFalse(result) @patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object") - def test_create_dev_space(self, mock_create_namespaced_custom_object): - """Test creating a dev space""" + def test_create_space(self, mock_create_namespaced_custom_object): + """Test creating a space""" test_client = KubernetesClient() - dev_space_spec = {"spec": {"image": "test-image"}} + space_spec = {"spec": {"image": "test-image"}} - test_client.create_dev_space("test-namespace", dev_space_spec) + test_client.create_space("test-namespace", space_spec) mock_create_namespaced_custom_object.assert_called_once_with( group="sagemaker.aws.com", version="v1alpha1", namespace="test-namespace", plural="spaces", - body=dev_space_spec + body=space_spec ) @patch("kubernetes.client.CustomObjectsApi.list_namespaced_custom_object") - def test_list_dev_spaces_with_namespace(self, mock_list_namespaced_custom_object): - """Test listing dev spaces in a specific namespace""" + def test_list_spaces_with_namespace(self, mock_list_namespaced_custom_object): + """Test listing spaces in a specific namespace""" test_client = KubernetesClient() mock_list_namespaced_custom_object.return_value = {"items": []} - result = test_client.list_dev_spaces("test-namespace") + result = test_client.list_spaces("test-namespace") mock_list_namespaced_custom_object.assert_called_once_with( group="sagemaker.aws.com", @@ -731,12 +731,12 @@ def test_list_dev_spaces_with_namespace(self, mock_list_namespaced_custom_object self.assertEqual(result, {"items": []}) @patch("kubernetes.client.CustomObjectsApi.list_cluster_custom_object") - def test_list_dev_spaces_without_namespace(self, mock_list_cluster_custom_object): - """Test listing dev spaces across 
all namespaces""" + def test_list_spaces_without_namespace(self, mock_list_cluster_custom_object): + """Test listing spaces across all namespaces""" test_client = KubernetesClient() mock_list_cluster_custom_object.return_value = {"items": []} - result = test_client.list_dev_spaces(None) + result = test_client.list_spaces(None) mock_list_cluster_custom_object.assert_called_once_with( group="sagemaker.aws.com", @@ -746,13 +746,13 @@ def test_list_dev_spaces_without_namespace(self, mock_list_cluster_custom_object self.assertEqual(result, {"items": []}) @patch("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object") - def test_get_dev_space(self, mock_get_namespaced_custom_object): - """Test getting a specific dev space""" + def test_get_space(self, mock_get_namespaced_custom_object): + """Test getting a specific space""" test_client = KubernetesClient() - mock_dev_space = {"metadata": {"name": "test-space"}} - mock_get_namespaced_custom_object.return_value = mock_dev_space + mock_space = {"metadata": {"name": "test-space"}} + mock_get_namespaced_custom_object.return_value = mock_space - result = test_client.get_dev_space("test-namespace", "test-space") + result = test_client.get_space("test-namespace", "test-space") mock_get_namespaced_custom_object.assert_called_once_with( group="sagemaker.aws.com", @@ -761,15 +761,15 @@ def test_get_dev_space(self, mock_get_namespaced_custom_object): plural="spaces", name="test-space" ) - self.assertEqual(result, mock_dev_space) + self.assertEqual(result, mock_space) @patch("kubernetes.client.CustomObjectsApi.delete_namespaced_custom_object") - def test_delete_dev_space(self, mock_delete_namespaced_custom_object): - """Test deleting a dev space""" + def test_delete_space(self, mock_delete_namespaced_custom_object): + """Test deleting a space""" test_client = KubernetesClient() mock_delete_namespaced_custom_object.return_value = {} - result = test_client.delete_dev_space("test-namespace", "test-space") + result = 
test_client.delete_space("test-namespace", "test-space") mock_delete_namespaced_custom_object.assert_called_once_with( group="sagemaker.aws.com", @@ -781,13 +781,13 @@ def test_delete_dev_space(self, mock_delete_namespaced_custom_object): self.assertEqual(result, {}) @patch("kubernetes.client.CustomObjectsApi.patch_namespaced_custom_object") - def test_patch_dev_space(self, mock_patch_namespaced_custom_object): - """Test patching a dev space""" + def test_patch_space(self, mock_patch_namespaced_custom_object): + """Test patching a space""" test_client = KubernetesClient() patch_body = {"spec": {"desiredStatus": "Running"}} mock_patch_namespaced_custom_object.return_value = {} - result = test_client.patch_dev_space("test-namespace", "test-space", patch_body) + result = test_client.patch_space("test-namespace", "test-space", patch_body) mock_patch_namespaced_custom_object.assert_called_once_with( group="sagemaker.aws.com", From 514630782a594f14f81a935c82d41e1737b64963 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Thu, 30 Oct 2025 10:22:57 -0700 Subject: [PATCH 11/31] Update the Space model and constants per latest operator (#275) --- .../hyperpod_space_template/v1_0/model.py | 291 ++++++++-------- .../hyperpod_space_template/v1_0/schema.json | 317 +++++++++++++----- .../hyperpod/cli/clients/kubernetes_client.py | 1 - src/sagemaker/hyperpod/cli/commands/space.py | 10 +- .../hyperpod/cli/constants/space_constants.py | 8 +- src/sagemaker/hyperpod/cli/space_utils.py | 148 +++++++- test/unit_tests/cli/test_space.py | 26 +- test/unit_tests/cli/test_space_utils.py | 249 +++++++++++++- .../clients/test_kubernetes_client.py | 24 +- 9 files changed, 822 insertions(+), 252 deletions(-) diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py index b8b06758..016c2978 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py +++ 
b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py @@ -1,62 +1,91 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator -from typing import Optional, List, Dict, Literal +from typing import Optional, List, Dict, Literal, Any from enum import Enum -# TODO: Temporarily removed for private beta -# class VolumeConfig(BaseModel): -# name: str = Field( -# ..., -# description="Volume name", -# min_length=1 -# ) -# type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type") -# mount_path: str = Field( -# ..., -# description="Mount path in container", -# min_length=1 -# ) -# path: Optional[str] = Field( -# None, -# description="Host path (required for hostPath volumes)", -# min_length=1 -# ) -# claim_name: Optional[str] = Field( -# None, -# description="PVC claim name (required for pvc volumes)", -# min_length=1 -# ) -# read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes") - - -class SharedStatus(str, Enum): - PUBLIC = "public" - PRIVATE = "private" - - -class Application(str, Enum): - JUPYTER = "jupyter" - CODE_EDITOR = "code-editor" - - -class ResourcesConfig(BaseModel): - memory: Optional[str] = Field(default="1Gi", description="Memory limit") - cpu: Optional[str] = Field(default="500m", description="CPU limit") - nvidia_gpu: Optional[str] = Field(default=None, alias="nvidia.com/gpu", description="GPU limit") +class OwnershipType(str, Enum): + PUBLIC = "Public" + OWNER_ONLY = "OwnerOnly" + + +class DesiredStatus(str, Enum): + RUNNING = "Running" + STOPPED = "Stopped" + + +class VolumeSpec(BaseModel): + """VolumeSpec defines a volume to mount from an existing PVC""" + name: str = Field( + description="Name is a unique identifier for this volume within the pod (maps to pod.spec.volumes[].name)", + min_length=1 + ) + mount_path: str = Field( + alias="mountPath", + description="MountPath is the path where the volume should be mounted (Unix-style path, e.g. 
/data)", + min_length=1 + ) + persistent_volume_claim_name: str = Field( + alias="persistentVolumeClaimName", + description="PersistentVolumeClaimName is the name of the existing PVC to mount", + min_length=1 + ) + + +class ContainerConfig(BaseModel): + """ContainerConfig defines container command and args configuration""" + command: Optional[List[str]] = Field( + default=None, + description="Command specifies the container command" + ) + args: Optional[List[str]] = Field( + default=None, + description="Args specifies the container arguments" + ) + + +class StorageSpec(BaseModel): + """StorageSpec defines the storage configuration for Workspace""" + storage_class_name: Optional[str] = Field( + default=None, + alias="storageClassName", + description="StorageClassName specifies the storage class to use for persistent storage" + ) + size: Optional[str] = Field( + default="10Gi", + description="Size specifies the size of the persistent volume. Supports standard Kubernetes resource quantities (e.g., '10Gi', '500Mi', '1Ti'). Integer values without units are interpreted as bytes" + ) + mount_path: Optional[str] = Field( + default="/home", + alias="mountPath", + description="MountPath specifies where to mount the persistent volume in the container. Default is /home/jovyan (jovyan is the standard user in Jupyter images)" + ) + + +class ResourceRequirements(BaseModel): + """ResourceRequirements describes the compute resource requirements""" + requests: Optional[Dict[str, str]] = Field( + default=None, + description="Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. Requests cannot exceed Limits." + ) + limits: Optional[Dict[str, str]] = Field( + default=None, + description="Limits describes the maximum amount of compute resources allowed." 
+ ) class SpaceConfig(BaseModel): + """SpaceConfig defines the desired state of a Space""" model_config = ConfigDict(extra="forbid") name: str = Field( - description="Dev space name", + description="Space name", min_length=1, max_length=63, pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$' ) - image: Optional[str] = Field( - default="public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu", - description="Container image for the space", + display_name: str = Field( + alias="display_name", + description="Display Name of the space", min_length=1 ) namespace: str = Field( @@ -64,81 +93,77 @@ class SpaceConfig(BaseModel): description="Kubernetes namespace", min_length=1 ) - desired_status: Optional[Literal['Running', 'Stopped']] = Field( - default="Running", + image: Optional[str] = Field( + default=None, + description="Image specifies the container image to use" + ) + desired_status: Optional[DesiredStatus] = Field( + default=None, alias="desired_status", - description="Desired status of the space" + description="DesiredStatus specifies the desired operational status" ) - service_account_name: Optional[str] = Field( - default="default", - alias="service_account_name", - description="Service account name", - min_length=1 + ownership_type: Optional[OwnershipType] = Field( + default=None, + alias="ownership_type", + description="OwnershipType specifies who can modify the space. Public means anyone with RBAC permissions can update/delete the space. OwnerOnly means only the creator can update/delete the space." 
) - resources: Optional[ResourcesConfig] = Field( - default=ResourcesConfig(), - description="Resource limit" + resources: Optional[ResourceRequirements] = Field( + default=None, + description="Resources specifies the resource requirements" ) - storage_class_name: Optional[str] = Field( + storage: Optional[StorageSpec] = Field( default=None, - alias="storage_class_name", - description="Storage class name", - min_length=1 + description="Storage specifies the storage configuration" ) - storage_size: Optional[str] = Field( + volumes: Optional[List[VolumeSpec]] = Field( default=None, - alias="storage_size", - description="Storage size (e.g., '10Gi')", - min_length=1 + description="Volumes specifies additional volumes to mount from existing PersistentVolumeClaims" + ) + container_config: Optional[ContainerConfig] = Field( + default=None, + alias="container_config", + description="ContainerConfig specifies container command and args configuration" + ) + node_selector: Optional[Dict[str, str]] = Field( + default=None, + alias="node_selector", + description="NodeSelector specifies node selection constraints for the space pod (JSON)" + ) + affinity: Optional[Dict[str, Any]] = Field( + default=None, + description="Affinity specifies node affinity and anti-affinity rules for the space pod (JSON)" + ) + tolerations: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON)" ) - shared_status: Optional[SharedStatus] = Field( - default=SharedStatus.PRIVATE, - description="Space shared setting (private | public)" - ) - application: Optional[Application] = Field( - default=Application.JUPYTER, - description="Application to run in the container (jupyter | code-editor)" - ) - # TODO: Temporarily removed for private beta - # queue_name: Optional[str] = Field( - # default=None, - # alias="queue_name", - # description="Queue name for scheduling", - # min_length=1, - # 
max_length=63, - # pattern=r'^[a-z0-9]([-a-z0-9]*[a-z0-9])?$' - # ) - # priority: Optional[str] = Field( - # default=None, - # description="Priority class for scheduling", - # min_length=1 - # ) - # volume: Optional[List[VolumeConfig]] = Field( - # default=None, description="List of volume configurations. \ - # Command structure: --volume name=,type=,mount_path=, \ - # For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ - # For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \ - # If multiple --volume flag if multiple volumes are needed \ - # " - # ) - - # @field_validator('volume') - # def validate_no_duplicates(cls, v): - # """Validate no duplicate volume names or mount paths.""" - # if not v: - # return v + lifecycle: Optional[Dict[str, Any]] = Field( + default=None, + description="Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON)" + ) + template_ref: Optional[str] = Field( + default=None, + alias="template_ref", + description="TemplateRef references a WorkspaceTemplate to use as base configuration. When set, template provides defaults and spec fields (Image, Resources, Storage.Size) act as overrides." 
+ ) + + @field_validator('volumes') + def validate_no_duplicate_volumes(cls, v): + """Validate no duplicate volume names or mount paths.""" + if not v: + return v - # # Check for duplicate volume names - # names = [vol.name for vol in v] - # if len(names) != len(set(names)): - # raise ValueError("Duplicate volume names found") + # Check for duplicate volume names + names = [vol.name for vol in v] + if len(names) != len(set(names)): + raise ValueError("Duplicate volume names found") - # # Check for duplicate mount paths - # mount_paths = [vol.mount_path for vol in v] - # if len(mount_paths) != len(set(mount_paths)): - # raise ValueError("Duplicate mount paths found") + # Check for duplicate mount paths + mount_paths = [vol.mount_path for vol in v] + if len(mount_paths) != len(set(mount_paths)): + raise ValueError("Duplicate mount paths found") - # return v + return v def to_domain(self) -> Dict: """ @@ -146,44 +171,44 @@ def to_domain(self) -> Dict: """ # Create the space spec spec = { - "image": self.image + "displayName": self.display_name } # Add optional spec fields + if self.image is not None: + spec["image"] = self.image if self.desired_status is not None: - spec["desiredStatus"] = self.desired_status - if self.service_account_name is not None: - spec["serviceAccountName"] = self.service_account_name + spec["desiredStatus"] = self.desired_status.value + if self.ownership_type is not None: + spec["ownershipType"] = self.ownership_type.value if self.resources is not None: spec["resources"] = self.resources.model_dump(exclude_none=True) - if self.storage_class_name is not None: - spec["storageClassName"] = self.storage_class_name - if self.storage_size is not None: - spec["storageSize"] = self.storage_size - if self.shared_status is not None: - spec["sharedStatus"] = self.shared_status.value - if self.application is not None: - spec["application"] = self.application.value + if self.storage is not None: + spec["storage"] = 
self.storage.model_dump(exclude_none=True, by_alias=True) + if self.volumes is not None: + spec["volumes"] = [vol.model_dump(exclude_none=True, by_alias=True) for vol in self.volumes] + if self.container_config is not None: + spec["containerConfig"] = self.container_config.model_dump(exclude_none=True) + if self.node_selector is not None: + spec["nodeSelector"] = self.node_selector + if self.affinity is not None: + spec["affinity"] = self.affinity + if self.tolerations is not None: + spec["tolerations"] = self.tolerations + if self.lifecycle is not None: + spec["lifecycle"] = self.lifecycle + if self.template_ref is not None: + spec["templateRef"] = self.template_ref # Create metadata metadata = {"name": self.name} if self.namespace is not None: metadata["namespace"] = self.namespace - # Add labels for scheduling - # labels = {} - # if self.queue_name is not None: - # labels["kueue.x-k8s.io/queue-name"] = self.queue_name - # if self.priority is not None: - # labels["kueue.x-k8s.io/priority-class"] = self.priority - - # if labels: - # metadata["labels"] = labels - # Create the complete space configuration space_config = { - "apiVersion": "sagemaker.aws.com/v1alpha1", - "kind": "Space", + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "Workspace", "metadata": metadata, "spec": spec } diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json index 82693bb8..30aa045d 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json @@ -1,16 +1,103 @@ { "$defs": { - "Application": { + "ContainerConfig": { + "description": "ContainerConfig defines container command and args configuration", + "properties": { + "command": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Command specifies the 
container command", + "title": "Command" + }, + "args": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Args specifies the container arguments", + "title": "Args" + } + }, + "title": "ContainerConfig", + "type": "object" + }, + "DesiredStatus": { + "enum": [ + "Running", + "Stopped" + ], + "title": "DesiredStatus", + "type": "string" + }, + "OwnershipType": { "enum": [ - "jupyter", - "code-editor" + "Public", + "OwnerOnly" ], - "title": "Application", + "title": "OwnershipType", "type": "string" }, - "ResourcesConfig": { + "ResourceRequirements": { + "description": "ResourceRequirements describes the compute resource requirements", + "properties": { + "requests": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. 
Requests cannot exceed Limits.", + "title": "Requests" + }, + "limits": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Limits describes the maximum amount of compute resources allowed.", + "title": "Limits" + } + }, + "title": "ResourceRequirements", + "type": "object" + }, + "StorageSpec": { + "description": "StorageSpec defines the storage configuration for Workspace", "properties": { - "memory": { + "storageClassName": { "anyOf": [ { "type": "string" @@ -19,11 +106,11 @@ "type": "null" } ], - "default": "1Gi", - "description": "Memory limit", - "title": "Memory" + "default": null, + "description": "StorageClassName specifies the storage class to use for persistent storage", + "title": "Storageclassname" }, - "cpu": { + "size": { "anyOf": [ { "type": "string" @@ -32,11 +119,11 @@ "type": "null" } ], - "default": "500m", - "description": "CPU limit", - "title": "Cpu" + "default": "10Gi", + "description": "Size specifies the size of the persistent volume. Supports standard Kubernetes resource quantities (e.g., '10Gi', '500Mi', '1Ti'). Integer values without units are interpreted as bytes", + "title": "Size" }, - "nvidia.com/gpu": { + "mountPath": { "anyOf": [ { "type": "string" @@ -45,156 +132,236 @@ "type": "null" } ], - "default": null, - "description": "GPU limit", - "title": "Nvidia.Com/Gpu" + "default": "/home", + "description": "MountPath specifies where to mount the persistent volume in the container. 
Default is /home/jovyan (jovyan is the standard user in Jupyter images)", + "title": "Mountpath" } }, - "title": "ResourcesConfig", + "title": "StorageSpec", "type": "object" }, - "SharedStatus": { - "enum": [ - "public", - "private" + "VolumeSpec": { + "description": "VolumeSpec defines a volume to mount from an existing PVC", + "properties": { + "name": { + "description": "Name is a unique identifier for this volume within the pod (maps to pod.spec.volumes[].name)", + "minLength": 1, + "title": "Name", + "type": "string" + }, + "mountPath": { + "description": "MountPath is the path where the volume should be mounted (Unix-style path, e.g. /data)", + "minLength": 1, + "title": "Mountpath", + "type": "string" + }, + "persistentVolumeClaimName": { + "description": "PersistentVolumeClaimName is the name of the existing PVC to mount", + "minLength": 1, + "title": "Persistentvolumeclaimname", + "type": "string" + } + }, + "required": [ + "name", + "mountPath", + "persistentVolumeClaimName" ], - "title": "SharedStatus", - "type": "string" + "title": "VolumeSpec", + "type": "object" } }, "additionalProperties": false, + "description": "SpaceConfig defines the desired state of a Space", "properties": { "name": { - "description": "Dev space name", + "description": "Space name", "maxLength": 63, "minLength": 1, "pattern": "^[a-z0-9]([-a-z0-9]*[a-z0-9])?$", "title": "Name", "type": "string" }, + "display_name": { + "description": "Display Name of the space", + "minLength": 1, + "title": "Display Name", + "type": "string" + }, + "namespace": { + "default": "default", + "description": "Kubernetes namespace", + "minLength": 1, + "title": "Namespace", + "type": "string" + }, "image": { "anyOf": [ { - "minLength": 1, "type": "string" }, { "type": "null" } ], - "default": "public.ecr.aws/sagemaker/sagemaker-distribution:3.2.0-cpu", - "description": "Container image for the space", + "default": null, + "description": "Image specifies the container image to use", "title": "Image" }, 
- "namespace": { - "default": "default", - "description": "Kubernetes namespace", - "minLength": 1, - "title": "Namespace", - "type": "string" - }, "desired_status": { "anyOf": [ { - "enum": [ - "Running", - "Stopped" - ], - "type": "string" + "$ref": "#/$defs/DesiredStatus" }, { "type": "null" } ], - "default": "Running", - "description": "Desired status of the space", - "title": "Desired Status" + "default": null, + "description": "DesiredStatus specifies the desired operational status" }, - "service_account_name": { + "ownership_type": { "anyOf": [ { - "minLength": 1, - "type": "string" + "$ref": "#/$defs/OwnershipType" }, { "type": "null" } ], - "default": "default", - "description": "Service account name", - "title": "Service Account Name" + "default": null, + "description": "OwnershipType specifies who can modify the space. Public means anyone with RBAC permissions can update/delete the space. OwnerOnly means only the creator can update/delete the space." }, "resources": { "anyOf": [ { - "$ref": "#/$defs/ResourcesConfig" + "$ref": "#/$defs/ResourceRequirements" }, { "type": "null" } ], - "default": { - "memory": "1Gi", - "cpu": "500m", - "nvidia.com/gpu": null - }, - "description": "Resource limit" + "default": null, + "description": "Resources specifies the resource requirements" }, - "storage_class_name": { + "storage": { "anyOf": [ { - "minLength": 1, - "type": "string" + "$ref": "#/$defs/StorageSpec" }, { "type": "null" } ], "default": null, - "description": "Storage class name", - "title": "Storage Class Name" + "description": "Storage specifies the storage configuration" }, - "storage_size": { + "volumes": { "anyOf": [ { - "minLength": 1, - "type": "string" + "items": { + "$ref": "#/$defs/VolumeSpec" + }, + "type": "array" }, { "type": "null" } ], "default": null, - "description": "Storage size (e.g., '10Gi')", - "title": "Storage Size" + "description": "Volumes specifies additional volumes to mount from existing PersistentVolumeClaims", + "title": 
"Volumes" }, - "shared_status": { + "container_config": { "anyOf": [ { - "$ref": "#/$defs/SharedStatus" + "$ref": "#/$defs/ContainerConfig" }, { "type": "null" } ], - "default": "private", - "description": "Space shared setting (private | public)" + "default": null, + "description": "ContainerConfig specifies container command and args configuration" }, - "application": { + "node_selector": { "anyOf": [ { - "$ref": "#/$defs/Application" + "additionalProperties": { + "type": "string" + }, + "type": "object" }, { "type": "null" } ], - "default": "jupyter", - "description": "Application to run in the container (jupyter | code-editor)" + "default": null, + "description": "NodeSelector specifies node selection constraints for the space pod (JSON)", + "title": "Node Selector" + }, + "affinity": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Affinity specifies node affinity and anti-affinity rules for the space pod (JSON)", + "title": "Affinity" + }, + "tolerations": { + "anyOf": [ + { + "items": { + "additionalProperties": true, + "type": "object" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON)", + "title": "Tolerations" + }, + "lifecycle": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON)", + "title": "Lifecycle" + }, + "template_ref": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "TemplateRef references a WorkspaceTemplate to use as base configuration. 
When set, template provides defaults and spec fields (Image, Resources, Storage.Size) act as overrides.", + "title": "Template Ref" } }, "required": [ - "name" + "name", + "display_name" ], "title": "SpaceConfig", "type": "object" diff --git a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py index 55576857..18bb07d6 100644 --- a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py +++ b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py @@ -44,7 +44,6 @@ SPACE_GROUP, SPACE_VERSION, SPACE_PLURAL, - DEFAULT_SPACE_PORT, ) from sagemaker.hyperpod.cli.constants.space_admin_config_constants import ( SPACE_ADMIN_CONFIG_GROUP, diff --git a/src/sagemaker/hyperpod/cli/commands/space.py b/src/sagemaker/hyperpod/cli/commands/space.py index f8e0473b..450fcaac 100644 --- a/src/sagemaker/hyperpod/cli/commands/space.py +++ b/src/sagemaker/hyperpod/cli/commands/space.py @@ -26,7 +26,7 @@ def space_create(version, config): k8s_client = KubernetesClient() k8s_client.create_space(namespace, space_spec) - click.echo(f"Dev space '{name}' created successfully in namespace '{namespace}'") + click.echo(f"Space '{name}' created successfully in namespace '{namespace}'") except Exception as e: click.echo(f"Error creating space: {e}", err=True) @@ -91,7 +91,7 @@ def space_delete(name, namespace): try: k8s_client.delete_space(namespace, name) - click.echo(f"Dev space '{name}' deleted successfully") + click.echo(f"Space '{name}' deleted successfully") except Exception as e: click.echo(f"Error deleting space '{name}': {e}", err=True) @@ -117,7 +117,7 @@ def space_update(version, config): body=space_spec ) - click.echo(f"Dev space '{name}' updated successfully") + click.echo(f"Space '{name}' updated successfully") except Exception as e: click.echo(f"Error updating space '{name}': {e}", err=True) @@ -138,7 +138,7 @@ def space_start(name, namespace): body=patch_body ) - click.echo(f"Dev space '{name}' start requested") + 
click.echo(f"Space '{name}' start requested") except Exception as e: click.echo(f"Error starting space '{name}': {e}", err=True) @@ -159,7 +159,7 @@ def space_stop(name, namespace): body=patch_body ) - click.echo(f"Dev space '{name}' stop requested") + click.echo(f"Space '{name}' stop requested") except Exception as e: click.echo(f"Error stopping space '{name}': {e}", err=True) diff --git a/src/sagemaker/hyperpod/cli/constants/space_constants.py b/src/sagemaker/hyperpod/cli/constants/space_constants.py index 006a9235..0c4d4453 100644 --- a/src/sagemaker/hyperpod/cli/constants/space_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_constants.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -SPACE_GROUP = "sagemaker.aws.com" +SPACE_GROUP = "workspace.jupyter.org" SPACE_VERSION = "v1alpha1" -SPACE_PLURAL = "spaces" -DEFAULT_SPACE_PORT = "8888" +SPACE_PLURAL = "workspaces" # Immutable fields that cannot be updated after space creation IMMUTABLE_FIELDS = { - "storage_class_name", + "storage", # storage is immutable per Go struct validation + "template_ref", # templateRef is immutable per Go struct validation } \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index b32071ab..03a497e5 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -46,12 +46,91 @@ def _build_resources(cpu, memory, gpu): if cpu is None and memory is None and gpu is None: return None - default_resources = props["resources"]["default"] + # Build requests dictionary + requests = {} + if cpu is not None: + requests["cpu"] = cpu + if memory is not None: + requests["memory"] = memory + if gpu is not None: + requests["nvidia.com/gpu"] = gpu + + # Return ResourceRequirements structure 
return { - "cpu": cpu or default_resources["cpu"], - "memory": memory or default_resources["memory"], - "nvidia.com/gpu": gpu or default_resources["nvidia.com/gpu"] + "requests": requests } + + def _parse_volume_param(ctx, param, value): + """Parse volume parameters from command line format to dictionary format.""" + if not value: + return None + + volumes = [] + for i, v in enumerate(value): + try: + # Split by comma and then by equals, with validation + parts = {} + for item in v.split(','): + if '=' not in item: + raise click.UsageError(f"Invalid volume format in volume {i+1}: '{item}' should be key=value") + key, val = item.split('=', 1) # Split only on first '=' to handle values with '=' + # Convert snake_case to match model field names + if key.strip() == 'mount_path': + key = 'mountPath' + elif key.strip() == 'persistent_volume_claim_name': + key = 'persistentVolumeClaimName' + parts[key.strip()] = val.strip() + + volumes.append(parts) + except Exception as e: + raise click.UsageError(f"Error parsing volume {i+1}: {str(e)}") + + return volumes + + def _parse_storage_param(ctx, param, value): + """Parse storage parameters from command line format to dictionary format.""" + if not value: + return None + + try: + parts = {} + for item in value.split(','): + if '=' not in item: + raise click.UsageError(f"Invalid storage format: '{item}' should be key=value") + key, val = item.split('=', 1) + # Convert snake_case to match model field names + if key.strip() == 'storage_class_name': + key = 'storageClassName' + elif key.strip() == 'mount_path': + key = 'mountPath' + parts[key.strip()] = val.strip() + return parts + except Exception as e: + raise click.UsageError(f"Error parsing storage: {str(e)}") + + def _parse_container_config_param(ctx, param, value): + """Parse container config parameters from command line format to dictionary format.""" + if not value: + return None + + try: + parts = {} + for item in value.split(','): + if '=' not in item: + raise 
click.UsageError(f"Invalid container-config format: '{item}' should be key=value") + key, val = item.split('=', 1) + key = key.strip() + val = val.strip() + + # Handle array fields (command and args) + if key in ['command', 'args']: + parts[key] = [item.strip() for item in val.split(';') if item.strip()] + else: + parts[key] = val + + return parts + except Exception as e: + raise click.UsageError(f"Error parsing container-config: {str(e)}") # 1) the wrapper click will call def wrapped_func(*args, **kwargs): @@ -65,10 +144,41 @@ def wrapped_func(*args, **kwargs): if resources is not None: kwargs["resources"] = resources + volumes = kwargs.pop("volume", None) + if volumes is not None: + kwargs["volumes"] = volumes + + storage = kwargs.pop("storage", None) + if storage is not None: + kwargs["storage"] = storage + + container_config = kwargs.pop("container_config", None) + if container_config is not None: + kwargs["container_config"] = container_config + # filter out None/empty values so Pydantic model defaults apply filtered_kwargs = {} for key, value in kwargs.items(): if value is not None: + # Parse JSON for object/array type parameters + spec = props.get(key, {}) + is_object_type = False + + if spec.get("type") == "object" or spec.get("type") == "array": + is_object_type = True + elif "anyOf" in spec: + # Check if any of the anyOf options is an object/array type + for option in spec["anyOf"]: + if option.get("type") == "object" or option.get("type") == "array": + is_object_type = True + break + + if isinstance(value, str) and is_object_type: + try: + value = json.loads(value) + except json.JSONDecodeError: + raise click.UsageError(f"Invalid JSON for --{key.replace('_', '-')}: {value}") + filtered_kwargs[key] = value try: @@ -77,12 +187,12 @@ def wrapped_func(*args, **kwargs): except ValidationError as e: error_messages = [] for err in e.errors(): - loc = ".".join(str(x) for x in err["loc"]) + loc = ".".join(str(x).replace('_','-') for x in err["loc"]) msg = 
err["msg"] error_messages.append(f" – {loc}: {msg}") raise click.UsageError( - f"❌ Configuration validation errors:\n" + "\n".join(error_messages) + f"Configuration validation errors:\n" + "\n".join(error_messages) ) return func(version, domain_config) @@ -109,11 +219,35 @@ def wrapped_func(*args, **kwargs): help="Gpu resource, e.g. '1'", )(wrapped_func) + wrapped_func = click.option( + "--volume", + multiple=True, + callback=_parse_volume_param, + help="Volume configuration. Format: --volume name=,mountPath=,persistentVolumeClaimName=. Use multiple --volume flags for multiple volumes.", + )(wrapped_func) + + # Only add storage option if not in update mode as storage is immutable + if not is_update: + wrapped_func = click.option( + "--storage", + callback=_parse_storage_param, + help="Storage configuration. Format: --storage storageClassName=,size=,mountPath=", + )(wrapped_func) + + wrapped_func = click.option( + "--container-config", + callback=_parse_container_config_param, + help="Container configuration. 
Format: --container-config command=,args=", + )(wrapped_func) + # Exclude the props that were handled out of the below for loop excluded_props = set( [ "resources", "version", + "volumes", + "storage", + "container_config", ] ) @@ -136,6 +270,8 @@ def wrapped_func(*args, **kwargs): ctype = float elif spec.get("type") == "boolean": ctype = bool + elif spec.get("type") == "object": + ctype = str # JSON string input else: ctype = str diff --git a/test/unit_tests/cli/test_space.py b/test/unit_tests/cli/test_space.py index 111f4770..07851f9b 100644 --- a/test/unit_tests/cli/test_space.py +++ b/test/unit_tests/cli/test_space.py @@ -30,9 +30,10 @@ def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, + "display_name": {"type": "string"}, "namespace": {"type": "string"} }, - "required": ["name"] + "required": ["name", "display_name"] } # Mock model registry @@ -40,6 +41,7 @@ def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): mock_model.return_value = Mock() mock_model.return_value.to_domain.return_value = { "name": "test-space", + "display_name": "Test Space", "namespace": "test-ns", "space_spec": {"spec": {"image": "test-image"}} } @@ -52,11 +54,12 @@ def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): result = self.runner.invoke(space_create, [ '--version', '1.0', '--name', 'test-space', + '--display-name', 'Test Space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Dev space 'test-space' created successfully" in result.output + assert "Space 'test-space' created successfully" in result.output mock_k8s_instance.create_space.assert_called_once() @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') @@ -82,6 +85,7 @@ def test_space_create_k8s_error(self, mock_k8s_client_class): mock_model.return_value = Mock() mock_model.return_value.to_domain.return_value = { "name": "test-space", + 
"display_name": "Test Space", "namespace": "test-ns", "space_spec": {} } @@ -91,13 +95,15 @@ def test_space_create_k8s_error(self, mock_k8s_client_class): mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, + "display_name": {"type": "string"}, "namespace": {"type": "string"} }, - "required": ["name", "namespace"] + "required": ["name", "display_name"] } result = self.runner.invoke(space_create, [ '--version', '1.0', '--name', 'test-space', + '--display-name', 'Test Space', '--namespace', 'test-ns' ]) @@ -245,7 +251,7 @@ def test_space_delete_success(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "Dev space 'test-space' deleted successfully" in result.output + assert "Space 'test-space' deleted successfully" in result.output mock_k8s_instance.delete_space.assert_called_once_with('test-ns', 'test-space') @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') @@ -271,6 +277,7 @@ def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, + "display_name": {"type": "string"}, "namespace": {"type": "string"} }, "required": ["name"] @@ -293,11 +300,12 @@ def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): result = self.runner.invoke(space_update, [ '--version', '1.0', '--name', 'test-space', + '--display-name', 'Test Space', '--namespace', 'test-ns' ]) assert result.exit_code == 0 - assert "Dev space 'test-space' updated successfully" in result.output + assert "Space 'test-space' updated successfully" in result.output mock_k8s_instance.patch_space.assert_called_once() @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') @@ -320,13 +328,15 @@ def test_space_update_k8_error(self, mock_k8s_client_class): mock_load_schema.return_value = { "properties": { "name": {"type": "string"}, + "display_name": {"type": "string"}, "namespace": {"type": "string"} }, - "required": ["name", 
"namespace"] + "required": ["name"] } result = self.runner.invoke(space_update, [ '--version', '1.0', '--name', 'test-space', + '--display-name', 'Test Space', '--namespace', 'test-ns' ]) @@ -345,7 +355,7 @@ def test_space_start_success(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "Dev space 'test-space' start requested" in result.output + assert "Space 'test-space' start requested" in result.output mock_k8s_instance.patch_space.assert_called_once_with( namespace='test-ns', name='test-space', @@ -379,7 +389,7 @@ def test_space_stop_success(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "Dev space 'test-space' stop requested" in result.output + assert "Space 'test-space' stop requested" in result.output mock_k8s_instance.patch_space.assert_called_once_with( namespace='test-ns', name='test-space', diff --git a/test/unit_tests/cli/test_space_utils.py b/test/unit_tests/cli/test_space_utils.py index dfe8d389..b8529f14 100644 --- a/test/unit_tests/cli/test_space_utils.py +++ b/test/unit_tests/cli/test_space_utils.py @@ -125,16 +125,16 @@ def cmd(version, domain_config): result = self.runner.invoke(cmd, ['--cpu', '1000m', '--memory', '1Gi']) assert result.exit_code == 0 output = json.loads(result.output) - assert output['cpu'] == '1000m' - assert output['memory'] == '1Gi' - assert output['nvidia.com/gpu'] is None + assert output['requests']['cpu'] == '1000m' + assert output['requests']['memory'] == '1Gi' + assert 'nvidia.com/gpu' not in output['requests'] # Test with only CPU result = self.runner.invoke(cmd, ['--cpu', '750m']) assert result.exit_code == 0 output = json.loads(result.output) - assert output['cpu'] == '750m' - assert output['memory'] == '256Mi' # default + assert output['requests']['cpu'] == '750m' + assert 'memory' not in output['requests'] # Test with no resources specified result = self.runner.invoke(cmd, []) @@ -229,7 +229,8 @@ def test_immutable_fields_excluded_in_update(self, mock_load_schema): schema = 
{ 'properties': { 'name': {'type': 'string'}, - 'storage_class_name': {'type': 'string'}, + 'storage': {'type': 'object'}, # storage is immutable + 'template_ref': {'type': 'string'}, # template_ref is immutable 'image': {'type': 'string'} }, 'required': ['name'] @@ -256,8 +257,9 @@ def cmd(version, domain_config): # Get the command's help to check available options result = self.runner.invoke(cmd, ['--help']) assert result.exit_code == 0 - # storage_class_name should not be available in update mode - assert '--storage-class-name' not in result.output + # storage and template_ref should not be available in update mode + assert '--storage' not in result.output + assert '--template-ref' not in result.output # but other fields should be available assert '--name' in result.output assert '--image' in result.output @@ -361,3 +363,234 @@ def cmd(version, domain_config): print(result.output) assert result.exit_code == 0 assert result.output.strip() == 'success' + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_volume_parsing(self, mock_load_schema): + """Test volume parameter parsing""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'volumes': {'type': 'array'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(getattr(domain_config, 'volumes', None))) + + # Test valid volume parsing + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--volume', 'name=vol1,mountPath=/data,persistentVolumeClaimName=pvc1' + ]) + assert result.exit_code == 0 + volumes = json.loads(result.output) + assert len(volumes) == 1 + assert volumes[0]['name'] == 'vol1' + assert volumes[0]['mountPath'] == 
'/data' + assert volumes[0]['persistentVolumeClaimName'] == 'pvc1' + + # Test multiple volumes + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--volume', 'name=vol1,mountPath=/data1', + '--volume', 'name=vol2,mountPath=/data2' + ]) + assert result.exit_code == 0 + volumes = json.loads(result.output) + assert len(volumes) == 2 + + # Test invalid volume format + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--volume', 'invalid_format' + ]) + assert result.exit_code == 2 + assert 'Invalid volume format' in result.output + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_storage_parsing(self, mock_load_schema): + """Test storage parameter parsing""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'storage': {'type': 'object'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(getattr(domain_config, 'storage', None))) + + # Test valid storage parsing + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--storage', 'storageClassName=gp2,size=20Gi,mountPath=/data' + ]) + assert result.exit_code == 0 + storage = json.loads(result.output) + assert storage['storageClassName'] == 'gp2' + assert storage['size'] == '20Gi' + assert storage['mountPath'] == '/data' + + # Test invalid storage format + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--storage', 'invalid_format' + ]) + assert result.exit_code == 2 + assert 'Invalid storage format' in result.output + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_container_config_parsing_simple(self, mock_load_schema): + """Test container config 
parameter parsing with simple format""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'container_config': {'type': 'object'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(getattr(domain_config, 'container_config', None))) + + # Test valid container config with semicolon format + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--container-config', 'command=python;app.py,args=--port;8080' + ]) + assert result.exit_code == 0 + config = json.loads(result.output) + assert config['command'] == ['python', 'app.py'] + assert config['args'] == ['--port', '8080'] + + # Test invalid container config format + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--container-config', 'invalid_format' + ]) + assert result.exit_code == 2 + assert 'Invalid container-config format' in result.output + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_json_object_parsing(self, mock_load_schema): + """Test JSON object parameter parsing""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'metadata': {'type': 'object'}, + 'tags': {'type': 'array'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + result = { + 'metadata': getattr(domain_config, 'metadata', None), + 'tags': getattr(domain_config, 'tags', None) + } + click.echo(json.dumps(result)) 
+ + # Test valid JSON object + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--metadata', '{"key": "value", "number": 42}', + '--tags', '["tag1", "tag2"]' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['metadata']['key'] == 'value' + assert output['metadata']['number'] == 42 + assert output['tags'] == ['tag1', 'tag2'] + + # Test invalid JSON + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--metadata', 'invalid json' + ]) + assert result.exit_code == 2 + assert 'Invalid JSON for --metadata' in result.output + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_anyof_type_handling(self, mock_load_schema): + """Test handling of anyOf type specifications""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'config': { + 'anyOf': [ + {'type': 'object'}, + {'type': 'null'} + ] + } + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(getattr(domain_config, 'config', None))) + + # Test with JSON object for anyOf type + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--config', '{"setting": "value"}' + ]) + assert result.exit_code == 0 + config = json.loads(result.output) + assert config['setting'] == 'value' diff --git a/test/unit_tests/clients/test_kubernetes_client.py b/test/unit_tests/clients/test_kubernetes_client.py index 5c30a17b..5e466419 100644 --- a/test/unit_tests/clients/test_kubernetes_client.py +++ b/test/unit_tests/clients/test_kubernetes_client.py @@ -707,10 +707,10 @@ def test_create_space(self, mock_create_namespaced_custom_object): test_client.create_space("test-namespace", space_spec) 
mock_create_namespaced_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", namespace="test-namespace", - plural="spaces", + plural="workspaces", body=space_spec ) @@ -723,10 +723,10 @@ def test_list_spaces_with_namespace(self, mock_list_namespaced_custom_object): result = test_client.list_spaces("test-namespace") mock_list_namespaced_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", namespace="test-namespace", - plural="spaces" + plural="workspaces" ) self.assertEqual(result, {"items": []}) @@ -739,9 +739,9 @@ def test_list_spaces_without_namespace(self, mock_list_cluster_custom_object): result = test_client.list_spaces(None) mock_list_cluster_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", - plural="spaces" + plural="workspaces" ) self.assertEqual(result, {"items": []}) @@ -755,10 +755,10 @@ def test_get_space(self, mock_get_namespaced_custom_object): result = test_client.get_space("test-namespace", "test-space") mock_get_namespaced_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", namespace="test-namespace", - plural="spaces", + plural="workspaces", name="test-space" ) self.assertEqual(result, mock_space) @@ -772,10 +772,10 @@ def test_delete_space(self, mock_delete_namespaced_custom_object): result = test_client.delete_space("test-namespace", "test-space") mock_delete_namespaced_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", namespace="test-namespace", - plural="spaces", + plural="workspaces", name="test-space" ) self.assertEqual(result, {}) @@ -790,10 +790,10 @@ def test_patch_space(self, mock_patch_namespaced_custom_object): result = test_client.patch_space("test-namespace", "test-space", patch_body) 
mock_patch_namespaced_custom_object.assert_called_once_with( - group="sagemaker.aws.com", + group="workspace.jupyter.org", version="v1alpha1", namespace="test-namespace", - plural="spaces", + plural="workspaces", name="test-space", body=patch_body ) From 046d4c8e426e10b2e3fa165298592405efa1ceb6 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Thu, 30 Oct 2025 11:19:55 -0700 Subject: [PATCH 12/31] Add space_admin_config.py CLI command (#260) * Add space_admin_config.py CLI command * Update the space admin config to space template --------- Co-authored-by: Brian Xia --- .../hyperpod/cli/clients/kubernetes_client.py | 81 ++--- .../hyperpod/cli/commands/space_template.py | 117 +++++++ ...nstants.py => space_template_constants.py} | 6 +- src/sagemaker/hyperpod/cli/hyp_cli.py | 20 +- test/unit_tests/cli/test_space_template.py | 329 ++++++++++++++++++ .../clients/test_kubernetes_client.py | 215 ++++++++++++ 6 files changed, 713 insertions(+), 55 deletions(-) create mode 100644 src/sagemaker/hyperpod/cli/commands/space_template.py rename src/sagemaker/hyperpod/cli/constants/{space_admin_config_constants.py => space_template_constants.py} (80%) create mode 100644 test/unit_tests/cli/test_space_template.py diff --git a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py index 18bb07d6..96c92bb7 100644 --- a/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py +++ b/src/sagemaker/hyperpod/cli/clients/kubernetes_client.py @@ -45,10 +45,10 @@ SPACE_VERSION, SPACE_PLURAL, ) -from sagemaker.hyperpod.cli.constants.space_admin_config_constants import ( - SPACE_ADMIN_CONFIG_GROUP, - SPACE_ADMIN_CONFIG_VERSION, - SPACE_ADMIN_CONFIG_PLURAL, +from sagemaker.hyperpod.cli.constants.space_template_constants import ( + SPACE_TEMPLATE_GROUP, + SPACE_TEMPLATE_VERSION, + SPACE_TEMPLATE_PLURAL, ) from sagemaker.hyperpod.cli.constants.space_access_constants import ( SPACE_ACCESS_GROUP, @@ -426,66 +426,51 @@ def patch_space(self, 
namespace: str, name: str, body: dict): body=body ) - - - # Space Admin Configuration methods - def create_space_admin_config(self, namespace: str, config_spec: dict): - return client.CustomObjectsApi().create_namespaced_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - namespace=namespace, - plural=SPACE_ADMIN_CONFIG_PLURAL, + # Space Template Configuration methods + def create_space_template(self, config_spec: dict): + return client.CustomObjectsApi().create_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, body=config_spec ) - def list_space_admin_configs(self, namespace: str = None): - if namespace: - return client.CustomObjectsApi().list_namespaced_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - namespace=namespace, - plural=SPACE_ADMIN_CONFIG_PLURAL - ) - else: - return client.CustomObjectsApi().list_cluster_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - plural=SPACE_ADMIN_CONFIG_PLURAL - ) + def list_space_templates(self): + return client.CustomObjectsApi().list_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL + ) - def get_space_admin_config(self, namespace: str, name: str): - return client.CustomObjectsApi().get_namespaced_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - namespace=namespace, - plural=SPACE_ADMIN_CONFIG_PLURAL, + def get_space_template(self, name: str): + return client.CustomObjectsApi().get_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, name=name ) - def delete_space_admin_config(self, namespace: str, name: str): - return client.CustomObjectsApi().delete_namespaced_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - namespace=namespace, - 
plural=SPACE_ADMIN_CONFIG_PLURAL, + def delete_space_template(self, name: str): + return client.CustomObjectsApi().delete_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, name=name ) - def patch_space_admin_config(self, namespace: str, name: str, body: dict): - return client.CustomObjectsApi().patch_namespaced_custom_object( - group=SPACE_ADMIN_CONFIG_GROUP, - version=SPACE_ADMIN_CONFIG_VERSION, - namespace=namespace, - plural=SPACE_ADMIN_CONFIG_PLURAL, + def patch_space_template(self, name: str, body: dict): + return client.CustomObjectsApi().patch_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, name=name, body=body ) - def create_space_access(self, namespace: str, config_spec: dict): - return client.CustomObjectsApi().create_namespaced_custom_object( + def create_space_access(self, config_spec: dict): + return client.CustomObjectsApi().create_cluster_custom_object( group=SPACE_ACCESS_GROUP, version=SPACE_ACCESS_VERSION, - namespace=namespace, plural=SPACE_ACCESS_PLURAL, body=config_spec ) diff --git a/src/sagemaker/hyperpod/cli/commands/space_template.py b/src/sagemaker/hyperpod/cli/commands/space_template.py new file mode 100644 index 00000000..92941d79 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/commands/space_template.py @@ -0,0 +1,117 @@ +import click +import json +import yaml +from tabulate import tabulate +from sagemaker.hyperpod.cli.clients.kubernetes_client import KubernetesClient + + +@click.command("hyp-space-template") +@click.option("--file", "-f", required=True, help="YAML file containing the configuration") +def space_template_create(file): + """Create a space-template resource.""" + k8s_client = KubernetesClient() + + try: + with open(file, 'r') as f: + config_data = yaml.safe_load(f) + + k8s_client.create_space_template(config_data) + click.echo(f"Space template '{config_data['metadata']['name']}' created 
successfully") + except FileNotFoundError: + click.echo(f"Error: File '{file}' not found", err=True) + except yaml.YAMLError as e: + click.echo(f"Error parsing YAML file: {e}", err=True) + except Exception as e: + click.echo(f"Error creating space template: {e}", err=True) + + +@click.command("hyp-space-template") +@click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") +def space_template_list(output): + """List space-template resources.""" + k8s_client = KubernetesClient() + + try: + resources = k8s_client.list_space_templates() + + if output == "json": + click.echo(json.dumps(resources, indent=2)) + else: + items = resources.get("items", []) + if items: + table_data = [] + for item in items: + table_data.append([ + item["metadata"]["name"], + ]) + click.echo(tabulate(table_data, headers=["NAME"])) + else: + click.echo("No space templates found") + except Exception as e: + click.echo(f"Error listing space templates: {e}", err=True) + + +@click.command("hyp-space-template") +@click.option("--name", required=False, help="Name of the space template") +@click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") +def space_template_describe(name, output): + """Describe a space-template resource.""" + k8s_client = KubernetesClient() + + try: + resource = k8s_client.get_space_template(name) + resource["metadata"].pop('managedFields', None) + + if output == "json": + click.echo(json.dumps(resource, indent=2)) + else: + click.echo(yaml.dump(resource, default_flow_style=False)) + except Exception as e: + click.echo(f"Error describing space template '{name}': {e}", err=True) + + +@click.command("hyp-space-template") +@click.option("--name", required=False, help="Name of the space template") +def space_template_delete(name): + """Delete a space-template resource.""" + k8s_client = KubernetesClient() + + try: + k8s_client.delete_space_template(name) + click.echo(f"Space template '{name}' deleted successfully") + except 
Exception as e: + click.echo(f"Error deleting space template '{name}': {e}", err=True) + + +@click.command("hyp-space-template") +@click.option("--name", required=True, help="Name of the space template") +@click.option("--file", "-f", required=True, help="YAML file containing the updated template") +def space_template_update(name, file): + """Update a space-template resource.""" + k8s_client = KubernetesClient() + + try: + with open(file, 'r') as f: + config_data = yaml.safe_load(f) + + # Validate that the name matches + yaml_name = config_data.get('metadata', {}).get('name') + if yaml_name and yaml_name != name: + click.echo(f"Error: Name mismatch. CLI parameter '{name}' does not match YAML name '{yaml_name}'", err=True) + return + + # Remove immutable fields from the update + if 'metadata' in config_data: + config_data['metadata'].pop('resourceVersion', None) + config_data['metadata'].pop('uid', None) + config_data['metadata'].pop('creationTimestamp', None) + config_data['metadata'].pop('managedFields', None) + + k8s_client.patch_space_template(name, config_data) + click.echo(f"Space template '{name}' updated successfully") + except FileNotFoundError: + click.echo(f"Error: File '{file}' not found", err=True) + except yaml.YAMLError as e: + click.echo(f"Error parsing YAML file: {e}", err=True) + except Exception as e: + click.echo(f"Error updating space template '{name}': {e}", err=True) diff --git a/src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py b/src/sagemaker/hyperpod/cli/constants/space_template_constants.py similarity index 80% rename from src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py rename to src/sagemaker/hyperpod/cli/constants/space_template_constants.py index bd793538..664f25b6 100644 --- a/src/sagemaker/hyperpod/cli/constants/space_admin_config_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_template_constants.py @@ -11,6 +11,6 @@ # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. -SPACE_ADMIN_CONFIG_GROUP = "sagemaker.aws.com" -SPACE_ADMIN_CONFIG_VERSION = "v1alpha1" -SPACE_ADMIN_CONFIG_PLURAL = "spaceadminconfigs" +SPACE_TEMPLATE_GROUP = "workspace.jupyter.org" +SPACE_TEMPLATE_VERSION = "v1alpha1" +SPACE_TEMPLATE_PLURAL = "workspacetemplates" diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index 3904cc50..ec60f303 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -48,6 +48,13 @@ space_stop, space_get_logs, ) +from sagemaker.hyperpod.cli.commands.space_template import ( + space_template_create, + space_template_list, + space_template_describe, + space_template_delete, + space_template_update, +) from sagemaker.hyperpod.cli.commands.init import ( init, @@ -123,23 +130,23 @@ def create(): @cli.group(cls=CLICommand) def list(): - """List endpoints, pytorch jobs, cluster stacks or spaces.""" + """List endpoints, pytorch jobs, cluster stacks, spaces, and space templates.""" pass @cli.group(cls=CLICommand) def describe(): - """Describe endpoints, pytorch jobs or cluster stacks, spaces or space admin configs.""" + """Describe endpoints, pytorch jobs or cluster stacks, spaces or space template.""" pass @cli.group(cls=CLICommand) def update(): - """Update an existing HyperPod cluster configuration, space, or space admin config.""" + """Update an existing HyperPod cluster configuration, space, or space template.""" pass @cli.group(cls=CLICommand) def delete(): - """Delete endpoints, pytorch jobs, space, space access or space admin config.""" + """Delete endpoints, pytorch jobs, space, space access or space template.""" pass @@ -200,12 +207,14 @@ def exec(): _default_create.hidden = True create.add_command(_default_create) create.add_command(space_create) +create.add_command(space_template_create) list.add_command(list_jobs) list.add_command(js_list) 
list.add_command(custom_list) list.add_command(list_cluster_stacks) list.add_command(space_list) +list.add_command(space_template_list) describe.add_command(pytorch_describe) describe.add_command(js_describe) @@ -214,15 +223,18 @@ def exec(): describe.add_command(describe_cluster) describe.add_command(space_describe) +describe.add_command(space_template_describe) update.add_command(update_cluster) update.add_command(space_update) +update.add_command(space_template_update) delete.add_command(pytorch_delete) delete.add_command(js_delete) delete.add_command(custom_delete) delete.add_command(delete_cluster_stack) delete.add_command(space_delete) +delete.add_command(space_template_delete) start.add_command(space_start) diff --git a/test/unit_tests/cli/test_space_template.py b/test/unit_tests/cli/test_space_template.py new file mode 100644 index 00000000..e468c85d --- /dev/null +++ b/test/unit_tests/cli/test_space_template.py @@ -0,0 +1,329 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +import json +import unittest +import yaml +from unittest.mock import Mock, patch, mock_open +from click.testing import CliRunner + +from sagemaker.hyperpod.cli.commands.space_template import ( + space_template_create, + space_template_list, + space_template_describe, + space_template_delete, + space_template_update, +) + + +class TestSpaceTemplateCommands(unittest.TestCase): + def setUp(self): + self.runner = CliRunner() + self.mock_config_data = { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "WorkspaceTemplate", + "metadata": {"name": "test-template"}, + "spec": {"displayName": "Test Template"} + } + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") + def test_space_template_create_success(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test successful space template creation""" + mock_yaml_load.return_value = self.mock_config_data + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + + result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Space template 'test-template' created successfully", result.output) + mock_client_instance.create_space_template.assert_called_once_with(self.mock_config_data) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_create_file_not_found(self, mock_k8s_client): + """Test space template creation with missing file""" + result = self.runner.invoke(space_template_create, ["--file", "nonexistent.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error: File 'nonexistent.yaml' not found", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="invalid: yaml: content:") + @patch("yaml.safe_load") + def 
test_space_template_create_yaml_error(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template creation with YAML parsing error""" + mock_yaml_load.side_effect = yaml.YAMLError("Invalid YAML") + + result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error parsing YAML file: Invalid YAML", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") + def test_space_template_create_k8s_error(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template creation with Kubernetes error""" + mock_yaml_load.return_value = self.mock_config_data + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.create_space_template.side_effect = Exception("K8s error") + + result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error creating space template: K8s error", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_list_table_output(self, mock_k8s_client): + """Test space template list with table output""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.list_space_templates.return_value = { + "items": [ + {"metadata": {"name": "template1"}}, + {"metadata": {"name": "template2"}} + ] + } + + result = self.runner.invoke(space_template_list, ["--output", "table"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("template1", result.output) + self.assertIn("template2", result.output) + self.assertIn("NAME", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_list_json_output(self, mock_k8s_client): + """Test 
space template list with JSON output""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_resources = { + "items": [ + {"metadata": {"name": "template1"}}, + {"metadata": {"name": "template2"}} + ] + } + mock_client_instance.list_space_templates.return_value = mock_resources + + result = self.runner.invoke(space_template_list, ["--output", "json"]) + + self.assertEqual(result.exit_code, 0) + output_json = json.loads(result.output) + self.assertEqual(output_json, mock_resources) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_list_empty(self, mock_k8s_client): + """Test space template list with no templates""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.list_space_templates.return_value = {"items": []} + + result = self.runner.invoke(space_template_list, ["--output", "table"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("No space templates found", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_list_error(self, mock_k8s_client): + """Test space template list with error""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.list_space_templates.side_effect = Exception("List error") + + result = self.runner.invoke(space_template_list) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error listing space templates: List error", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_describe_yaml_output(self, mock_k8s_client): + """Test space template describe with YAML output""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_resource = { + "metadata": { + "name": "test-template", + "managedFields": [{"manager": "kubectl"}] + }, + "spec": 
{"displayName": "Test Template"} + } + mock_client_instance.get_space_template.return_value = mock_resource + + result = self.runner.invoke(space_template_describe, ["--name", "test-template", "--output", "yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("name: test-template", result.output) + self.assertIn("displayName: Test Template", result.output) + # managedFields should be removed + self.assertNotIn("managedFields", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_describe_json_output(self, mock_k8s_client): + """Test space template describe with JSON output""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_resource = { + "metadata": { + "name": "test-template", + "managedFields": [{"manager": "kubectl"}] + }, + "spec": {"displayName": "Test Template"} + } + mock_client_instance.get_space_template.return_value = mock_resource + + result = self.runner.invoke(space_template_describe, ["--name", "test-template", "--output", "json"]) + + self.assertEqual(result.exit_code, 0) + output_json = json.loads(result.output) + self.assertEqual(output_json["metadata"]["name"], "test-template") + self.assertNotIn("managedFields", output_json["metadata"]) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_describe_error(self, mock_k8s_client): + """Test space template describe with error""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.get_space_template.side_effect = Exception("Not found") + + result = self.runner.invoke(space_template_describe, ["--name", "nonexistent"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error describing space template 'nonexistent': Not found", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_delete_success(self, 
mock_k8s_client): + """Test successful space template deletion""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + + result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Space template 'test-template' deleted successfully", result.output) + mock_client_instance.delete_space_template.assert_called_once_with("test-template") + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_delete_error(self, mock_k8s_client): + """Test space template deletion with error""" + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.delete_space_template.side_effect = Exception("Delete error") + + result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error deleting space template 'test-template': Delete error", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") + def test_space_template_update_success(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test successful space template update""" + mock_yaml_load.return_value = self.mock_config_data + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Space template 'test-template' updated successfully", result.output) + mock_client_instance.patch_space_template.assert_called_once() + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") + def 
test_space_template_update_name_mismatch(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template update with name mismatch""" + config_with_different_name = self.mock_config_data.copy() + config_with_different_name["metadata"]["name"] = "different-name" + mock_yaml_load.return_value = config_with_different_name + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error: Name mismatch. CLI parameter 'test-template' does not match YAML name 'different-name'", result.output) + mock_client_instance.patch_space_template.assert_not_called() + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + def test_space_template_update_file_not_found(self, mock_k8s_client): + """Test space template update with missing file""" + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "nonexistent.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error: File 'nonexistent.yaml' not found", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="invalid: yaml: content:") + @patch("yaml.safe_load") + def test_space_template_update_yaml_error(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template update with YAML parsing error""" + mock_yaml_load.side_effect = yaml.YAMLError("Invalid YAML") + + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error parsing YAML file: Invalid YAML", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") 
+ def test_space_template_update_k8s_error(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template update with Kubernetes error""" + mock_yaml_load.return_value = self.mock_config_data + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + mock_client_instance.patch_space_template.side_effect = Exception("K8s error") + + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + self.assertIn("Error updating space template 'test-template': K8s error", result.output) + + @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") + @patch("builtins.open", new_callable=mock_open, read_data="test: data") + @patch("yaml.safe_load") + def test_space_template_update_removes_immutable_fields(self, mock_yaml_load, mock_file, mock_k8s_client): + """Test space template update removes immutable fields""" + config_with_immutable_fields = { + "metadata": { + "name": "test-template", + "resourceVersion": "12345", + "uid": "abc-123", + "creationTimestamp": "2023-01-01T00:00:00Z", + "managedFields": [{"manager": "kubectl"}] + }, + "spec": {"displayName": "Test Template"} + } + mock_yaml_load.return_value = config_with_immutable_fields + mock_client_instance = Mock() + mock_k8s_client.return_value = mock_client_instance + + result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) + + self.assertEqual(result.exit_code, 0) + # Verify patch was called with cleaned config + call_args = mock_client_instance.patch_space_template.call_args[0][1] + self.assertNotIn("resourceVersion", call_args["metadata"]) + self.assertNotIn("uid", call_args["metadata"]) + self.assertNotIn("creationTimestamp", call_args["metadata"]) + self.assertNotIn("managedFields", call_args["metadata"]) + self.assertEqual(call_args["metadata"]["name"], "test-template") + + +if __name__ == "__main__": + unittest.main() 
diff --git a/test/unit_tests/clients/test_kubernetes_client.py b/test/unit_tests/clients/test_kubernetes_client.py index 5e466419..baa0670a 100644 --- a/test/unit_tests/clients/test_kubernetes_client.py +++ b/test/unit_tests/clients/test_kubernetes_client.py @@ -799,3 +799,218 @@ def test_patch_space(self, mock_patch_namespaced_custom_object): ) self.assertEqual(result, {}) + @patch("kubernetes.client.CustomObjectsApi.create_cluster_custom_object") + def test_create_space_template(self, mock_create_cluster_custom_object): + """Test creating a space template""" + test_client = KubernetesClient() + config_spec = { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "WorkspaceTemplate", + "metadata": {"name": "test-template"}, + "spec": {"displayName": "Test Template"} + } + mock_create_cluster_custom_object.return_value = config_spec + + result = test_client.create_space_template(config_spec) + + mock_create_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + body=config_spec + ) + self.assertEqual(result, config_spec) + + @patch("kubernetes.client.CustomObjectsApi.create_cluster_custom_object") + def test_create_space_template_api_error(self, mock_create_cluster_custom_object): + """Test creating a space template with API error""" + test_client = KubernetesClient() + from kubernetes.client.rest import ApiException + config_spec = {"metadata": {"name": "test-template"}} + mock_create_cluster_custom_object.side_effect = ApiException(status=400, reason="Bad Request") + + with self.assertRaises(ApiException): + test_client.create_space_template(config_spec) + + @patch("kubernetes.client.CustomObjectsApi.create_cluster_custom_object") + def test_create_space_template_with_complex_spec(self, mock_create_cluster_custom_object): + """Test creating a space template with complex specification""" + test_client = KubernetesClient() + config_spec = { + "apiVersion": 
"workspace.jupyter.org/v1alpha1", + "kind": "WorkspaceTemplate", + "metadata": { + "name": "production-template", + "labels": {"environment": "prod"} + }, + "spec": { + "displayName": "Production Template", + "description": "Template for production workloads", + "defaultImage": "jupyter/scipy-notebook:latest", + "allowedImages": ["jupyter/scipy-notebook:latest", "jupyter/datascience-notebook:latest"], + "defaultResources": { + "requests": {"cpu": "200m", "memory": "256Mi"}, + "limits": {"cpu": "500m", "memory": "512Mi"} + }, + "resourceBounds": { + "cpu": {"min": "100m", "max": "2"}, + "memory": {"min": "128Mi", "max": "4Gi"} + } + } + } + mock_create_cluster_custom_object.return_value = config_spec + + result = test_client.create_space_template(config_spec) + + mock_create_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + body=config_spec + ) + self.assertEqual(result, config_spec) + + @patch("kubernetes.client.CustomObjectsApi.list_cluster_custom_object") + def test_list_space_templates(self, mock_list_cluster_custom_object): + """Test listing space templates""" + test_client = KubernetesClient() + mock_templates = { + "items": [ + {"metadata": {"name": "template1"}}, + {"metadata": {"name": "template2"}} + ] + } + mock_list_cluster_custom_object.return_value = mock_templates + + result = test_client.list_space_templates() + + mock_list_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates" + ) + self.assertEqual(result, mock_templates) + + @patch("kubernetes.client.CustomObjectsApi.list_cluster_custom_object") + def test_list_space_templates_empty(self, mock_list_cluster_custom_object): + """Test listing space templates when none exist""" + test_client = KubernetesClient() + mock_list_cluster_custom_object.return_value = {"items": []} + + result = test_client.list_space_templates() + + 
mock_list_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates" + ) + self.assertEqual(result, {"items": []}) + + @patch("kubernetes.client.CustomObjectsApi.get_cluster_custom_object") + def test_get_space_template(self, mock_get_cluster_custom_object): + """Test getting a specific space template""" + test_client = KubernetesClient() + mock_template = { + "metadata": {"name": "test-template"}, + "spec": {"displayName": "Test Template"} + } + mock_get_cluster_custom_object.return_value = mock_template + + result = test_client.get_space_template("test-template") + + mock_get_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template" + ) + self.assertEqual(result, mock_template) + + @patch("kubernetes.client.CustomObjectsApi.get_cluster_custom_object") + def test_get_space_template_not_found(self, mock_get_cluster_custom_object): + """Test getting a space template that doesn't exist""" + test_client = KubernetesClient() + from kubernetes.client.rest import ApiException + mock_get_cluster_custom_object.side_effect = ApiException(status=404, reason="Not Found") + + with self.assertRaises(ApiException): + test_client.get_space_template("nonexistent-template") + + @patch("kubernetes.client.CustomObjectsApi.delete_cluster_custom_object") + def test_delete_space_template(self, mock_delete_cluster_custom_object): + """Test deleting a space template""" + test_client = KubernetesClient() + mock_delete_cluster_custom_object.return_value = {} + + result = test_client.delete_space_template("test-template") + + mock_delete_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template" + ) + self.assertEqual(result, {}) + + @patch("kubernetes.client.CustomObjectsApi.delete_cluster_custom_object") + def 
test_delete_space_template_not_found(self, mock_delete_cluster_custom_object): + """Test deleting a space template that doesn't exist""" + test_client = KubernetesClient() + from kubernetes.client.rest import ApiException + mock_delete_cluster_custom_object.side_effect = ApiException(status=404, reason="Not Found") + + with self.assertRaises(ApiException): + test_client.delete_space_template("nonexistent-template") + + @patch("kubernetes.client.CustomObjectsApi.patch_cluster_custom_object") + def test_patch_space_template_success(self, mock_patch_cluster_custom_object): + """Test successful space template patch""" + test_client = KubernetesClient() + patch_body = { + "spec": { + "displayName": "Updated Template", + "description": "Updated description" + } + } + mock_patch_cluster_custom_object.return_value = patch_body + + result = test_client.patch_space_template("test-template", patch_body) + + mock_patch_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template", + body=patch_body + ) + self.assertEqual(result, patch_body) + + @patch("kubernetes.client.CustomObjectsApi.patch_cluster_custom_object") + def test_patch_space_template_with_complex_body(self, mock_patch_cluster_custom_object): + """Test space template patch with complex body""" + test_client = KubernetesClient() + patch_body = { + "metadata": { + "labels": {"environment": "production", "version": "v2"} + }, + "spec": { + "displayName": "Production Template v2", + "description": "Updated production template", + "defaultResources": { + "requests": {"cpu": "500m", "memory": "1Gi"}, + "limits": {"cpu": "1", "memory": "2Gi"} + } + } + } + mock_patch_cluster_custom_object.return_value = patch_body + + result = test_client.patch_space_template("production-template", patch_body) + + mock_patch_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + 
plural="workspacetemplates", + name="production-template", + body=patch_body + ) + self.assertEqual(result, patch_body) From 3010df12a470c2adc32e29027235a98b21e042a6 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Sun, 2 Nov 2025 18:11:54 -0800 Subject: [PATCH 13/31] Implement CRUD operations for Space PySDK (#267) * Implement CRUD operations for Space PySDK * Update Space PySDK per new schema * Update Space PySDK per new schema --- src/sagemaker/hyperpod/cli/space_utils.py | 4 + src/sagemaker/hyperpod/space/__init__.py | 20 + .../hyperpod/space/hyperpod_space.py | 354 +++++++++++ src/sagemaker/hyperpod/space/utils.py | 57 ++ test/unit_tests/test_hyperpod_space.py | 571 ++++++++++++++++++ test/unit_tests/test_space_utils.py | 76 +++ 6 files changed, 1082 insertions(+) create mode 100644 src/sagemaker/hyperpod/space/__init__.py create mode 100644 src/sagemaker/hyperpod/space/hyperpod_space.py create mode 100644 src/sagemaker/hyperpod/space/utils.py create mode 100644 test/unit_tests/test_hyperpod_space.py create mode 100644 test/unit_tests/test_space_utils.py diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index 03a497e5..b87c4864 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -254,6 +254,10 @@ def wrapped_func(*args, **kwargs): # 3) auto-inject all schema.json fields reqs = set(schema.get("required", [])) + # Make display_name optional for update operation + if is_update and "display_name" in reqs: + reqs.remove("display_name") + for name, spec in reversed(list(props.items())): if name in excluded_props: continue diff --git a/src/sagemaker/hyperpod/space/__init__.py b/src/sagemaker/hyperpod/space/__init__.py new file mode 100644 index 00000000..c7ed320d --- /dev/null +++ b/src/sagemaker/hyperpod/space/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from sagemaker.hyperpod.space.hyperpod_space import HPSpace +from hyperpod_space_template.v1_0.model import SpaceConfig + +__all__ = [ + "HPSpace", + "SpaceConfig", +] diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py new file mode 100644 index 00000000..5982e1b7 --- /dev/null +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -0,0 +1,354 @@ +import logging +import yaml +from typing import List, Optional, ClassVar, Dict +from pydantic import BaseModel, Field, ConfigDict +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from sagemaker.hyperpod.common.config.metadata import Metadata +from sagemaker.hyperpod.common.utils import ( + handle_exception, + get_default_namespace, + setup_logging, + verify_kubernetes_version_compatibility +) +from sagemaker.hyperpod.space.utils import map_kubernetes_response_to_model +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature +from sagemaker.hyperpod.cli.constants.space_constants import ( + SPACE_GROUP, + SPACE_VERSION, + SPACE_PLURAL, +) +from hyperpod_space_template.v1_0.model import SpaceConfig + + +class HPSpace(BaseModel): + """HyperPod Space on Amazon SageMaker HyperPod clusters. + + This class provides methods to create, manage, and monitor spaces + on SageMaker HyperPod clusters orchestrated by Amazon EKS. 
+ """ + + is_kubeconfig_loaded: ClassVar[bool] = False + model_config = ConfigDict(extra="forbid") + + config: SpaceConfig = Field( + description="The space configuration using the template model" + ) + + @classmethod + def get_logger(cls): + """Get logger for the class.""" + return logging.getLogger(__name__) + + @classmethod + def verify_kube_config(cls): + """Verify and load Kubernetes configuration.""" + if not cls.is_kubeconfig_loaded: + try: + config.load_kube_config() + cls.is_kubeconfig_loaded = True + verify_kubernetes_version_compatibility(cls.get_logger()) + except Exception as e: + raise RuntimeError(f"Failed to load kubeconfig: {e}") + + def space_exists(self): + """Check if the space already exists""" + custom_api = client.CustomObjectsApi() + try: + custom_api.get_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=self.config.namespace, + plural=SPACE_PLURAL, + name=self.config.name + ) + return True + except ApiException as e: + if e.status == 404: + return False + # re-raise if exception is not 404 (Not found) + raise + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space") + def create(self, debug: bool = False): + """Create and submit the HyperPod Space to the Kubernetes cluster. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. 
+ + Raises: + Exception: If the space creation fails or Kubernetes API call fails + """ + self.verify_kube_config() + + logger = self.get_logger() + logger = setup_logging(logger, debug) + + if self.space_exists(): + logger.info(f"HyperPod Space '{self.config.name}' already exists in namespace '{self.config.namespace}'") + return + + # Convert config to domain model + domain_config = self.config.to_domain() + config_body = domain_config["space_spec"] + + logger.debug( + "Creating HyperPod Space with config:\n%s", + yaml.dump(config_body), + ) + + custom_api = client.CustomObjectsApi() + + try: + custom_api.create_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=self.config.namespace, + plural=SPACE_PLURAL, + body=config_body, + ) + logger.info(f"Successfully created HyperPod Space '{self.config.name}'!") + except Exception as e: + logger.error(f"Failed to create HyperPod Space {self.config.name}!") + handle_exception(e, self.config.name, self.config.namespace) + + @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_spaces") + def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: + """List all HyperPod Spaces in the specified namespace. + + Args: + namespace (str, optional): The Kubernetes namespace to list spaces from. + If None, uses the default namespace from current context. 
+ + Returns: + List[HPSpace]: List of HPSpace instances found in the namespace + + Raises: + Exception: If the Kubernetes API call fails or spaces cannot be retrieved + """ + cls.verify_kube_config() + + if not namespace: + namespace = get_default_namespace() + + custom_api = client.CustomObjectsApi() + + try: + response = custom_api.list_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=namespace, + plural=SPACE_PLURAL + ) + + spaces = [] + for item in response.get("items", []): + # Create SpaceConfig from the Kubernetes resource + spec = item.get("spec", {}) + config_data = { + "name": item["metadata"]["name"], + "namespace": item["metadata"]["namespace"], + } + + config_data = map_kubernetes_response_to_model(item, SpaceConfig) + space_config = SpaceConfig(**config_data) + + space = cls( + config=space_config, + ) + spaces.append(space) + + return spaces + except Exception as e: + handle_exception(e, "list", namespace) + + @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_space") + def get(cls, name: str, namespace: str = "default") -> "HPSpace": + """Get a specific HyperPod Space by name. + + Args: + name (str): The name of the space to retrieve + namespace (str, optional): The Kubernetes namespace. Defaults to "default". 
+ + Returns: + HPSpace: The space instance + + Raises: + Exception: If the space is not found or Kubernetes API call fails + """ + cls.verify_kube_config() + + custom_api = client.CustomObjectsApi() + + try: + response = custom_api.get_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=namespace, + plural=SPACE_PLURAL, + name=name + ) + + # Use dynamic mapping based on SpaceConfig model + config_data = map_kubernetes_response_to_model(response, SpaceConfig) + + space_config = SpaceConfig(**config_data) + + return cls( + config=space_config, + ) + except Exception as e: + handle_exception(e, name, namespace) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_space") + def delete(self): + """Delete the HyperPod Space from the Kubernetes cluster. + + Raises: + Exception: If the deletion fails or Kubernetes API call fails + """ + self.verify_kube_config() + logger = self.get_logger() + + if not self.space_exists(): + logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") + return + + custom_api = client.CustomObjectsApi() + + try: + custom_api.delete_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=self.config.namespace, + plural=SPACE_PLURAL, + name=self.config.name + ) + logger.info(f"Successfully deleted HyperPod Space '{self.config.name}'!") + except Exception as e: + logger.error(f"Failed to delete HyperPod Space {self.config.name}!") + handle_exception(e, self.config.name, self.config.namespace) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "update_space") + def update(self, **kwargs): + """Update the HyperPod Space configuration. 
+ + Args: + **kwargs: Configuration fields to update (e.g., desired_status="Stopped") + + Raises: + Exception: If the update fails or Kubernetes API call fails + """ + self.verify_kube_config() + logger = self.get_logger() + + if not self.space_exists(): + logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") + return + + custom_api = client.CustomObjectsApi() + + # Update the local config + for key, value in kwargs.items(): + if hasattr(self.config, key): + setattr(self.config, key, value) + + # Convert to domain model and extract spec + domain_config = self.config.to_domain() + spec_updates = domain_config["space_spec"]["spec"] + + try: + custom_api.patch_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=self.config.namespace, + plural=SPACE_PLURAL, + name=self.config.name, + body={"spec": spec_updates} + ) + logger.info(f"Successfully updated HyperPod Space '{self.config.name}'!") + except Exception as e: + logger.error(f"Failed to update HyperPod Space {self.config.name}!") + handle_exception(e, self.config.name, self.config.namespace) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "start_space") + def start(self): + """Start the HyperPod Space by setting desired status to Running.""" + self.update(desired_status="Running") + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "stop_space") + def stop(self): + """Stop the HyperPod Space by setting desired status to Stopped.""" + self.update(desired_status="Stopped") + + def list_pods(self) -> List[str]: + """List all pods associated with this space. 
+ + Returns: + List[str]: List of pod names associated with the space + """ + self.verify_kube_config() + logger = self.get_logger() + + if not self.space_exists(): + logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") + return [] + + v1 = client.CoreV1Api() + + try: + pods = v1.list_namespaced_pod( + namespace=self.config.namespace, + label_selector=f"sagemaker.aws.com/space-name={self.config.name}" + ) + return [pod.metadata.name for pod in pods.items] + except Exception as e: + handle_exception(e, self.config.name, self.config.namespace) + + def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = None) -> str: + """Get logs from a pod associated with this space. + + Args: + pod_name (str, optional): Name of the pod to get logs from. + If None, gets logs from the first available pod. + container (str, optional): Name of the container to get logs from. + + Returns: + str: The pod logs + """ + self.verify_kube_config() + logger = self.get_logger() + + if not self.space_exists(): + logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") + return "" + + if not pod_name: + pods = self.list_pods() + if not pods: + raise RuntimeError(f"No pods found for space '{self.config.name}'") + pod_name = pods[0] + + v1 = client.CoreV1Api() + + try: + if container: + logs = v1.read_namespaced_pod_log( + name=pod_name, + namespace=self.config.namespace, + container=container + ) + else: + logs = v1.read_namespaced_pod_log( + name=pod_name, + namespace=self.config.namespace + ) + return logs + except Exception as e: + handle_exception(e, pod_name, self.config.namespace) diff --git a/src/sagemaker/hyperpod/space/utils.py b/src/sagemaker/hyperpod/space/utils.py new file mode 100644 index 00000000..4b3023f6 --- /dev/null +++ b/src/sagemaker/hyperpod/space/utils.py @@ -0,0 +1,57 @@ +"""Utility functions for space operations.""" + +import re +from typing import 
Dict, Any, Set +from pydantic import BaseModel + + +def camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case.""" + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + +def get_model_fields(model_class: BaseModel) -> Set[str]: + """Get all field names from a Pydantic model.""" + return set(model_class.model_fields.keys()) + + +def map_kubernetes_response_to_model(k8s_data: Dict[str, Any], model_class: BaseModel) -> Dict[str, Any]: + """ + Map Kubernetes API response to model-compatible format. + + Args: + k8s_data: Raw Kubernetes API response data + model_class: Pydantic model class to map to + + Returns: + Dict with fields mapped and filtered for the model + """ + model_fields = get_model_fields(model_class) + mapped_data = {} + + # Extract metadata fields + if 'metadata' in k8s_data: + metadata = k8s_data['metadata'] + if 'name' in metadata and 'name' in model_fields: + mapped_data['name'] = metadata['name'] + if 'namespace' in metadata and 'namespace' in model_fields: + mapped_data['namespace'] = metadata['namespace'] + + # Extract and map spec fields + if 'spec' in k8s_data: + spec = k8s_data['spec'] + for k8s_field, value in spec.items(): + snake_field = camel_to_snake(k8s_field) + if snake_field in model_fields: + mapped_data[snake_field] = value + + # Extract and map status fields + if 'status' in k8s_data: + status = k8s_data['status'] + for k8s_field, value in status.items(): + snake_field = camel_to_snake(k8s_field) + if snake_field in model_fields: + mapped_data[snake_field] = value + + return mapped_data diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py new file mode 100644 index 00000000..23ee7553 --- /dev/null +++ b/test/unit_tests/test_hyperpod_space.py @@ -0,0 +1,571 @@ +import unittest +from unittest.mock import Mock, patch, MagicMock +from kubernetes.client.rest import ApiException + +from sagemaker.hyperpod.space.hyperpod_space 
import HPSpace +from hyperpod_space_template.v1_0.model import SpaceConfig + + +class TestHPSpace(unittest.TestCase): + """Test cases for HPSpace PySDK""" + + def setUp(self): + """Setup test fixtures""" + self.mock_config = SpaceConfig( + name="test-space", + display_name="Test Space", + namespace="test-namespace", + image="test-image:latest", + desired_status="Running" + ) + self.hp_space = HPSpace(config=self.mock_config) + + @patch('sagemaker.hyperpod.space.hyperpod_space.config.load_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.verify_kubernetes_version_compatibility') + def test_verify_kube_config_success(self, mock_verify_k8s, mock_load_config): + """Test successful kubeconfig verification""" + HPSpace.is_kubeconfig_loaded = False + HPSpace.verify_kube_config() + + mock_load_config.assert_called_once() + mock_verify_k8s.assert_called_once() + self.assertTrue(HPSpace.is_kubeconfig_loaded) + + @patch('sagemaker.hyperpod.space.hyperpod_space.config.load_kube_config') + def test_verify_kube_config_failure(self, mock_load_config): + """Test kubeconfig verification failure""" + HPSpace.is_kubeconfig_loaded = False + mock_load_config.side_effect = Exception("Config load failed") + + with self.assertRaises(RuntimeError) as context: + HPSpace.verify_kube_config() + self.assertIn("Failed to load kubeconfig: Config load failed", str(context.exception)) + + def test_verify_kube_config_already_loaded(self): + """Test kubeconfig verification when already loaded""" + HPSpace.is_kubeconfig_loaded = True + + with patch('sagemaker.hyperpod.space.hyperpod_space.config.load_kube_config') as mock_load_config: + HPSpace.verify_kube_config() + mock_load_config.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_create_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test successful dev 
space creation""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = False + + # Mock the config.to_domain() method + mock_domain_config = { + "space_spec": { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "Workspace", + "metadata": {"name": "test-space", "namespace": "test-namespace"}, + "spec": {"image": "test-image:latest"} + } + } + + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.create() + + mock_verify_config.assert_called_once() + mock_space_exists.assert_called_once() + mock_custom_api.create_namespaced_custom_object.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_create_already_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test dev space creation when resource already exists""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = True + + self.hp_space.create() + + mock_verify_config.assert_called_once() + mock_space_exists.assert_called_once() + # Should not call create since resource exists + mock_custom_api.create_namespaced_custom_object.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + @patch.object(HPSpace, 'space_exists') + def test_create_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test dev space creation failure""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = False + + # Mock creation failure + mock_custom_api.create_namespaced_custom_object.side_effect = 
Exception("Creation failed") + + mock_domain_config = { + "space_spec": { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "Workspace", + "metadata": {"name": "test-space", "namespace": "test-namespace"}, + "spec": {"image": "test-image:latest"} + } + } + + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.create() + + mock_handle_exception.assert_called_once() + + def test_space_exists_success(self): + """Test space_exists method when space exists""" + with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.get_namespaced_custom_object.return_value = { + "metadata": {"name": "test-space", "namespace": "test-namespace"} + } + + result = self.hp_space.space_exists() + self.assertTrue(result) + + def test_space_exists_not_found(self): + """Test space_exists method when space doesn't exist""" + with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.get_namespaced_custom_object.side_effect = ApiException(status=404) + + result = self.hp_space.space_exists() + self.assertFalse(result) + + def test_space_exists_api_error(self): + """Test space_exists method with non-404 API error""" + with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.get_namespaced_custom_object.side_effect = ApiException(status=500) + + with self.assertRaises(ApiException): + self.hp_space.space_exists() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + 
@patch('sagemaker.hyperpod.space.hyperpod_space.get_default_namespace') + def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class): + """Test successful dev space listing""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_get_namespace.return_value = "default" + + mock_response = { + "items": [ + { + "metadata": {"name": "space1", "namespace": "default"}, + "spec": {"image": "image1:latest", "displayName": "Space 1"}, + }, + { + "metadata": {"name": "space2", "namespace": "default"}, + "spec": {"image": "image2:latest", "displayName": "Space 2"}, + } + ] + } + mock_custom_api.list_namespaced_custom_object.return_value = mock_response + + result = HPSpace.list() + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].config.name, "space1") + self.assertEqual(result[1].config.name, "space2") + mock_custom_api.list_namespaced_custom_object.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + def test_list_with_namespace(self, mock_verify_config, mock_custom_api_class): + """Test dev space listing with specific namespace""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = {"items": []} + mock_custom_api.list_namespaced_custom_object.return_value = mock_response + + HPSpace.list(namespace="custom-namespace") + + mock_custom_api.list_namespaced_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + namespace="custom-namespace", + plural="workspaces" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + def test_list_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test dev space listing failure""" + mock_custom_api = 
Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.list_namespaced_custom_object.side_effect = Exception("List failed") + + HPSpace.list(namespace="test-namespace") + + mock_handle_exception.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + def test_get_success(self, mock_verify_config, mock_custom_api_class): + """Test successful dev space retrieval""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = { + "metadata": {"name": "test-space", "namespace": "test-namespace"}, + "spec": {"image": "test-image:latest", "displayName": "Test Space"}, + } + mock_custom_api.get_namespaced_custom_object.return_value = mock_response + + result = HPSpace.get(name="test-space", namespace="test-namespace") + + self.assertEqual(result.config.name, "test-space") + mock_custom_api.get_namespaced_custom_object.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + def test_get_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test dev space retrieval failure""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.get_namespaced_custom_object.side_effect = Exception("Get failed") + + HPSpace.get(name="test-space", namespace="test-namespace") + + mock_custom_api.get_namespaced_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + namespace="test-namespace", + plural="workspaces", + name='test-space' + ) + mock_handle_exception.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def 
test_delete_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test successful dev space deletion""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = True + + self.hp_space.delete() + + mock_verify_config.assert_called_once() + mock_space_exists.assert_called_once() + mock_custom_api.delete_namespaced_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + namespace="test-namespace", + plural="workspaces", + name="test-space" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_delete_not_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test dev space deletion when space doesn't exist""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = False + + self.hp_space.delete() + + mock_verify_config.assert_called_once() + mock_space_exists.assert_called_once() + # Should not call delete since resource doesn't exist + mock_custom_api.delete_namespaced_custom_object.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + @patch.object(HPSpace, 'space_exists') + def test_delete_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test dev space deletion failure""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = True + mock_custom_api.delete_namespaced_custom_object.side_effect = Exception("Delete failed") + + self.hp_space.delete() + + mock_handle_exception.assert_called_once() + + 
@patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_update_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test successful dev space update""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = True + + mock_domain_config = { + "space_spec": { + "spec": {"desiredStatus": "Stopped"} + } + } + + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.update(desired_status="Stopped") + + mock_custom_api.patch_namespaced_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + namespace="test-namespace", + plural="workspaces", + name="test-space", + body={"spec": {"desiredStatus": "Stopped"}} + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_update_not_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + """Test dev space update when space doesn't exist""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = False + + self.hp_space.update(desired_status="Stopped") + + mock_verify_config.assert_called_once() + mock_space_exists.assert_called_once() + # Should not call update since resource doesn't exist + mock_custom_api.patch_namespaced_custom_object.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + @patch.object(HPSpace, 'space_exists') + def test_update_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test dev 
space update failure""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_space_exists.return_value = True + mock_custom_api.patch_namespaced_custom_object.side_effect = Exception("Update failed") + + mock_domain_config = {"space_spec": {"spec": {}}} + + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.update(desired_status="Stopped") + + mock_handle_exception.assert_called_once() + + @patch.object(HPSpace, 'update') + def test_start(self, mock_update): + """Test dev space start""" + self.hp_space.start() + mock_update.assert_called_once_with(desired_status="Running") + + @patch.object(HPSpace, 'update') + def test_stop(self, mock_update): + """Test dev space stop""" + self.hp_space.stop() + mock_update.assert_called_once_with(desired_status="Stopped") + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_list_pods_success(self, mock_space_exists, mock_verify_config, mock_core_api_class): + """Test successful pod listing""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + + mock_pod1 = Mock() + mock_pod1.metadata.name = "pod1" + mock_pod2 = Mock() + mock_pod2.metadata.name = "pod2" + + mock_pods = Mock() + mock_pods.items = [mock_pod1, mock_pod2] + mock_core_api.list_namespaced_pod.return_value = mock_pods + + result = self.hp_space.list_pods() + + self.assertEqual(result, ["pod1", "pod2"]) + mock_core_api.list_namespaced_pod.assert_called_once_with( + namespace="test-namespace", + label_selector="sagemaker.aws.com/space-name=test-space" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_list_pods_not_exists(self, mock_space_exists, mock_verify_config, 
mock_core_api_class): + """Test pod listing when space doesn't exist""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = False + + result = self.hp_space.list_pods() + + self.assertEqual(result, []) + mock_core_api.list_namespaced_pod.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + @patch.object(HPSpace, 'space_exists') + def test_list_pods_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_core_api_class): + """Test pod listing failure""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + mock_core_api.list_namespaced_pod.side_effect = Exception("List pods failed") + + self.hp_space.list_pods() + + mock_handle_exception.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + @patch.object(HPSpace, 'space_exists') + def test_get_logs_with_pod_name(self, mock_space_exists, mock_list_pods, mock_verify_config, mock_core_api_class): + """Test getting logs with specific pod name""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + mock_core_api.read_namespaced_pod_log.return_value = "test logs" + + result = self.hp_space.get_logs(pod_name="test-pod") + + self.assertEqual(result, "test logs") + mock_core_api.read_namespaced_pod_log.assert_called_once_with( + name="test-pod", + namespace="test-namespace" + ) + mock_list_pods.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + @patch.object(HPSpace, 'space_exists') + def 
test_get_logs_without_pod_name(self, mock_space_exists, mock_list_pods, mock_verify_config, mock_core_api_class): + """Test getting logs without pod name (uses first available pod)""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + mock_core_api.read_namespaced_pod_log.return_value = "test logs" + mock_list_pods.return_value = ["pod1", "pod2"] + + result = self.hp_space.get_logs() + + self.assertEqual(result, "test logs") + mock_core_api.read_namespaced_pod_log.assert_called_once_with( + name="pod1", + namespace="test-namespace" + ) + + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + @patch.object(HPSpace, 'space_exists') + def test_get_logs_no_pods(self, mock_space_exists, mock_list_pods, mock_verify_config): + """Test getting logs when no pods are available""" + mock_space_exists.return_value = True + mock_list_pods.return_value = [] + + with self.assertRaises(RuntimeError) as context: + self.hp_space.get_logs() + self.assertIn("No pods found for space 'test-space'", str(context.exception)) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_get_logs_with_container(self, mock_space_exists, mock_verify_config, mock_core_api_class): + """Test getting logs with specific container""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + mock_core_api.read_namespaced_pod_log.return_value = "container logs" + + result = self.hp_space.get_logs(pod_name="test-pod", container="test-container") + + self.assertEqual(result, "container logs") + mock_core_api.read_namespaced_pod_log.assert_called_once_with( + name="test-pod", + namespace="test-namespace", + container="test-container" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 
'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_get_logs_not_exists(self, mock_space_exists, mock_verify_config, mock_core_api_class): + """Test getting logs when space doesn't exist""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = False + + result = self.hp_space.get_logs(pod_name="test-pod") + + self.assertEqual(result, "") + mock_core_api.read_namespaced_pod_log.assert_not_called() + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + @patch.object(HPSpace, 'space_exists') + def test_get_logs_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_core_api_class): + """Test getting logs failure""" + mock_core_api = Mock() + mock_core_api_class.return_value = mock_core_api + mock_space_exists.return_value = True + mock_core_api.read_namespaced_pod_log.side_effect = Exception("Get logs failed") + + self.hp_space.get_logs(pod_name="test-pod") + + mock_handle_exception.assert_called_once() + + def test_model_validation(self): + """Test model validation with invalid config""" + with self.assertRaises(ValueError): + HPSpace(config="invalid_config") + + def test_model_extra_forbid(self): + """Test that extra fields are forbidden""" + with self.assertRaises(ValueError): + HPSpace(config=self.mock_config, extra_field="not_allowed") + + @patch('sagemaker.hyperpod.space.hyperpod_space.setup_logging') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'space_exists') + def test_create_debug_logging(self, mock_space_exists, mock_verify_config, mock_setup_logging): + """Test create method with debug logging enabled""" + mock_logger = Mock() + mock_setup_logging.return_value = mock_logger + mock_space_exists.return_value = False + + # Mock domain config for YAML serialization + mock_domain_config = { + 
"space_spec": { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "Workspace", + "metadata": {"name": "test-space", "namespace": "test-namespace"}, + "spec": {"image": "test-image:latest"} + } + } + + with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi'): + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.create(debug=True) + + mock_setup_logging.assert_called_once() + + def test_get_logger(self): + """Test get_logger class method""" + logger = HPSpace.get_logger() + self.assertEqual(logger.name, "sagemaker.hyperpod.space.hyperpod_space") diff --git a/test/unit_tests/test_space_utils.py b/test/unit_tests/test_space_utils.py new file mode 100644 index 00000000..025c0ee1 --- /dev/null +++ b/test/unit_tests/test_space_utils.py @@ -0,0 +1,76 @@ +"""Unit tests for space utils module.""" + +import unittest +from sagemaker.hyperpod.space.utils import camel_to_snake, get_model_fields, map_kubernetes_response_to_model +from hyperpod_space_template.v1_0.model import SpaceConfig + + +class TestSpaceUtils(unittest.TestCase): + """Test cases for space utils functions.""" + + def test_camel_to_snake(self): + """Test camelCase to snake_case conversion.""" + self.assertEqual(camel_to_snake("displayName"), "display_name") + self.assertEqual(camel_to_snake("desiredStatus"), "desired_status") + self.assertEqual(camel_to_snake("ownershipType"), "ownership_type") + self.assertEqual(camel_to_snake("image"), "image") + self.assertEqual(camel_to_snake("name"), "name") + + def test_get_model_fields(self): + """Test model fields extraction.""" + fields = get_model_fields(SpaceConfig) + expected_fields = { + 'name', 'display_name', 'namespace', 'image', 'desired_status', + 'ownership_type', 'resources', 'storage', 'volumes', 'container_config', + 'node_selector', 'affinity', 'tolerations', 'lifecycle', 'template_ref' + } + self.assertTrue(expected_fields.issubset(fields)) + + def 
test_map_kubernetes_response_to_model(self): + """Test Kubernetes response mapping to model format.""" + k8s_data = { + 'metadata': {'name': 'test-space', 'namespace': 'default'}, + 'spec': { + 'image': 'test:latest', + 'displayName': 'Test Space', + 'desiredStatus': 'Running', + 'unknownField': 'should be ignored' + }, + 'status': { + 'currentStatus': 'Running', + 'anotherUnknownField': 'also ignored' + } + } + + mapped = map_kubernetes_response_to_model(k8s_data, SpaceConfig) + + # Check that expected fields are mapped correctly + self.assertEqual(mapped['name'], 'test-space') + self.assertEqual(mapped['namespace'], 'default') + self.assertEqual(mapped['image'], 'test:latest') + self.assertEqual(mapped['display_name'], 'Test Space') + self.assertEqual(mapped['desired_status'], 'Running') + + # Check that unknown fields are filtered out + self.assertNotIn('unknownField', mapped) + self.assertNotIn('anotherUnknownField', mapped) + self.assertNotIn('currentStatus', mapped) + + def test_map_kubernetes_response_creates_valid_config(self): + """Test that mapped data creates valid SpaceConfig.""" + k8s_data = { + 'metadata': {'name': 'valid-space', 'namespace': 'test'}, + 'spec': { + 'image': 'valid:latest', + 'displayName': 'Valid Space', + 'desiredStatus': 'Running' + } + } + + mapped = map_kubernetes_response_to_model(k8s_data, SpaceConfig) + config = SpaceConfig(**mapped) + + self.assertEqual(config.name, 'valid-space') + self.assertEqual(config.display_name, 'Valid Space') + self.assertEqual(config.namespace, 'test') + self.assertEqual(config.image, 'valid:latest') From 13f1c0c7c8bf135aa4f4682c8a64c794fa544321 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Tue, 4 Nov 2025 15:02:59 -0800 Subject: [PATCH 14/31] Implement the pySDK for the Space Template (#282) --- src/sagemaker/hyperpod/cli/space_utils.py | 2 +- src/sagemaker/hyperpod/common/utils.py | 40 +- src/sagemaker/hyperpod/space/__init__.py | 2 + .../hyperpod/space/hyperpod_space_template.py | 235 
+++++++++++ test/unit_tests/common/test_utils.py | 30 ++ .../test_hyperpod_space_template.py | 371 ++++++++++++++++++ 6 files changed, 667 insertions(+), 13 deletions(-) create mode 100644 src/sagemaker/hyperpod/space/hyperpod_space_template.py create mode 100644 test/unit_tests/test_hyperpod_space_template.py diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index b87c4864..5cd2948d 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -237,7 +237,7 @@ def wrapped_func(*args, **kwargs): wrapped_func = click.option( "--container-config", callback=_parse_container_config_param, - help="Container configuration. Format: --container-config command=,args=", + help="Container configuration. Format: --container-config command=,args=", )(wrapped_func) # Exclude the props that were handled out of the below for loop diff --git a/src/sagemaker/hyperpod/common/utils.py b/src/sagemaker/hyperpod/common/utils.py index 15e73ba8..60ce01d1 100644 --- a/src/sagemaker/hyperpod/common/utils.py +++ b/src/sagemaker/hyperpod/common/utils.py @@ -38,7 +38,7 @@ def get_default_namespace(): "No active context. Please use set_cluster_context() method to set current context." ) -def handle_exception(e: Exception, name: str, namespace: str, +def handle_exception(e: Exception, name: str, namespace: Optional[str], operation_type: str = 'unknown', resource_type: str = 'unknown'): """ Handle various Kubernetes API exceptions for SDK usage (non-CLI). 
@@ -53,23 +53,39 @@ def handle_exception(e: Exception, name: str, namespace: str, operation_type: Operation type (legacy parameter, kept for backward compatibility) resource_type: Resource type (legacy parameter, kept for backward compatibility) """ + if isinstance(e, ApiException): if e.status == 401: raise Exception(f"Credentials unauthorized.") from e elif e.status == 403: - raise Exception( - f"Access denied to resource '{name}' in namespace '{namespace}'." - ) from e + if namespace: + raise Exception( + f"Access denied to resource '{name}' in namespace '{namespace}'." + ) from e + else: + raise Exception( + f"Access denied to resource '{name}'." + ) from e elif e.status == 404: - # Basic 404 for SDK usage - CLI commands get enhanced 404 via decorator - raise Exception( - f"Resource '{name}' not found in namespace '{namespace}'. " - f"Please check the resource name and namespace." - ) from e + if namespace: + # Basic 404 for SDK usage - CLI commands get enhanced 404 via decorator + raise Exception( + f"Resource '{name}' not found in namespace '{namespace}'. " + f"Please check the resource name and namespace." + ) from e + else: + raise Exception( + f"Resource '{name}' not found. Please check the resource name." + ) from e elif e.status == 409: - raise Exception( - f"Resource '{name}' already exists in namespace '{namespace}'." - ) from e + if namespace: + raise Exception( + f"Resource '{name}' already exists in namespace '{namespace}'." + ) from e + else: + raise Exception( + f"Resource '{name}' already exists." + ) from e elif 500 <= e.status < 600: raise Exception("Kubernetes API internal server error.") from e else: diff --git a/src/sagemaker/hyperpod/space/__init__.py b/src/sagemaker/hyperpod/space/__init__.py index c7ed320d..b1c18285 100644 --- a/src/sagemaker/hyperpod/space/__init__.py +++ b/src/sagemaker/hyperpod/space/__init__.py @@ -12,9 +12,11 @@ # language governing permissions and limitations under the License. 
from sagemaker.hyperpod.space.hyperpod_space import HPSpace +from sagemaker.hyperpod.space.hyperpod_space_template import HPSpaceTemplate from hyperpod_space_template.v1_0.model import SpaceConfig __all__ = [ "HPSpace", + "HPSpaceTemplate", "SpaceConfig", ] diff --git a/src/sagemaker/hyperpod/space/hyperpod_space_template.py b/src/sagemaker/hyperpod/space/hyperpod_space_template.py new file mode 100644 index 00000000..0499483e --- /dev/null +++ b/src/sagemaker/hyperpod/space/hyperpod_space_template.py @@ -0,0 +1,235 @@ +import logging +import yaml +from typing import List, Optional, ClassVar, Dict, Any, Union +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from sagemaker.hyperpod.common.utils import ( + handle_exception, + verify_kubernetes_version_compatibility +) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature +from sagemaker.hyperpod.cli.constants.space_template_constants import ( + SPACE_TEMPLATE_GROUP, + SPACE_TEMPLATE_VERSION, + SPACE_TEMPLATE_PLURAL, +) + + +class HPSpaceTemplate: + """HyperPod Space Template on Amazon SageMaker HyperPod clusters. + + This class provides methods to create, manage, and monitor space templates + on SageMaker HyperPod clusters orchestrated by Amazon EKS. + """ + + is_kubeconfig_loaded: ClassVar[bool] = False + + def __init__(self, file_path: str): + """Initialize space template with config YAML file path. 
+ + Args: + file_path: Path to YAML file + """ + try: + with open(file_path, 'r') as f: + self.config_data = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File '{file_path}' not found") + except yaml.YAMLError as e: + raise ValueError(f"Error parsing YAML file: {e}") + + self.name = self.config_data.get('metadata', {}).get('name') + + @classmethod + def get_logger(cls): + """Get logger for the class.""" + return logging.getLogger(__name__) + + @classmethod + def verify_kube_config(cls): + """Verify and load Kubernetes configuration.""" + if not cls.is_kubeconfig_loaded: + config.load_kube_config() + cls.is_kubeconfig_loaded = True + verify_kubernetes_version_compatibility(cls.get_logger()) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space_template") + def create(self) -> "HPSpaceTemplate": + """Create the space template in the cluster. + + Returns: + Updated HPSpaceTemplate instance with server response + """ + self.verify_kube_config() + + try: + api_instance = client.CustomObjectsApi() + response = api_instance.create_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, + body=self.config_data + ) + + self.config_data = response + self.get_logger().info(f"Space template '{self.name}' created successfully") + + except ApiException as e: + handle_exception(e, self.name, None) + except Exception as e: + self.get_logger().error(f"Error creating space template: {e}") + raise + + @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_space_templates") + def list(cls) -> List["HPSpaceTemplate"]: + """List all space templates. 
+ + Returns: + List of HPSpaceTemplate instances + """ + cls.verify_kube_config() + + try: + api_instance = client.CustomObjectsApi() + response = api_instance.list_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL + ) + + templates = [] + for item in response.get("items", []): + templates.append(cls(item)) + + return templates + + except ApiException as e: + handle_exception(e, "list", None) + except Exception as e: + cls.get_logger().error(f"Error listing space templates: {e}") + raise + + @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_space_template") + def get(cls, name: str) -> "HPSpaceTemplate": + """Get a specific space template by name. + + Args: + name: Name of the space template + + Returns: + HPSpaceTemplate instance + """ + cls.verify_kube_config() + + try: + api_instance = client.CustomObjectsApi() + response = api_instance.get_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, + name=name + ) + + # Remove managedFields for cleaner output + if 'metadata' in response: + response['metadata'].pop('managedFields', None) + + return cls(response) + + except ApiException as e: + handle_exception(e, name, None) + except Exception as e: + cls.get_logger().error(f"Error getting space template '{name}': {e}") + raise + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_space_template") + def delete(self) -> None: + """Delete the space template from the cluster.""" + self.verify_kube_config() + + try: + api_instance = client.CustomObjectsApi() + api_instance.delete_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, + name=self.name + ) + + self.get_logger().info(f"Space template '{self.name}' deleted successfully") + + except ApiException as e: + handle_exception(e, self.name, None) + except Exception as e: + 
self.get_logger().error(f"Error deleting space template '{self.name}': {e}") + raise + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "update_space_template") + def update(self, file_path: str) -> "HPSpaceTemplate": + """Update the space template from a YAML file. + + Args: + file_path: Path to the YAML configuration file + + Returns: + Updated HPSpaceTemplate instance + """ + self.verify_kube_config() + + try: + with open(file_path, 'r') as f: + config_data = yaml.safe_load(f) + + # Validate that the name matches + yaml_name = config_data.get('metadata', {}).get('name') + if yaml_name and yaml_name != self.name: + raise ValueError(f"Name mismatch. Template name '{self.name}' does not match YAML name '{yaml_name}'") + + # Remove immutable fields + if 'metadata' in config_data: + for field in ['resourceVersion', 'uid', 'creationTimestamp', 'managedFields']: + config_data['metadata'].pop(field, None) + + api_instance = client.CustomObjectsApi() + response = api_instance.patch_cluster_custom_object( + group=SPACE_TEMPLATE_GROUP, + version=SPACE_TEMPLATE_VERSION, + plural=SPACE_TEMPLATE_PLURAL, + name=self.name, + body=config_data + ) + + self.config_data = response + self.get_logger().info(f"Space template '{self.name}' updated successfully") + + except FileNotFoundError: + raise FileNotFoundError(f"File '{file_path}' not found") + except yaml.YAMLError as e: + raise ValueError(f"Error parsing YAML file: {e}") + except ApiException as e: + handle_exception(e, self.name, None) + except Exception as e: + self.get_logger().error(f"Error updating space template '{self.name}': {e}") + raise + + def to_yaml(self) -> str: + """Convert the space template to YAML format. + + Returns: + YAML string representation + """ + return yaml.dump(self.config_data, default_flow_style=False) + + def to_dict(self) -> Dict[str, Any]: + """Convert the space template to dictionary format. 
+ + Returns: + Dictionary representation + """ + return self.config_data diff --git a/test/unit_tests/common/test_utils.py b/test/unit_tests/common/test_utils.py index 7ba025b3..f43e37ff 100644 --- a/test/unit_tests/common/test_utils.py +++ b/test/unit_tests/common/test_utils.py @@ -39,6 +39,16 @@ def test_handle_api_exception_403(self): str(context.exception), ) + def test_handle_api_exception_403_without_namespace(self): + """Test handling 403 API exception""" + exception = ApiException(status=403) + with self.assertRaises(Exception) as context: + handle_exception(exception, "test-job", None) + self.assertIn( + "Access denied to resource 'test-job'", + str(context.exception), + ) + def test_handle_api_exception_404(self): """Test handling 404 API exception""" exception = ApiException(status=404) @@ -49,6 +59,16 @@ def test_handle_api_exception_404(self): str(context.exception), ) + def test_handle_api_exception_404_without_namespace(self): + """Test handling 404 API exception""" + exception = ApiException(status=404) + with self.assertRaises(Exception) as context: + handle_exception(exception, "test-job", None) + self.assertIn( + "Resource 'test-job' not found", + str(context.exception), + ) + def test_handle_api_exception_409(self): """Test handling 409 API exception""" exception = ApiException(status=409) @@ -59,6 +79,16 @@ def test_handle_api_exception_409(self): str(context.exception), ) + def test_handle_api_exception_409_without_namespace(self): + """Test handling 409 API exception""" + exception = ApiException(status=409) + with self.assertRaises(Exception) as context: + handle_exception(exception, "test-job", None) + self.assertIn( + "Resource 'test-job' already exists", + str(context.exception), + ) + def test_handle_api_exception_500(self): """Test handling 500 API exception""" exception = ApiException(status=500) diff --git a/test/unit_tests/test_hyperpod_space_template.py b/test/unit_tests/test_hyperpod_space_template.py new file mode 100644 index 
00000000..9d03534b --- /dev/null +++ b/test/unit_tests/test_hyperpod_space_template.py @@ -0,0 +1,371 @@ +import unittest +from unittest.mock import Mock, patch, mock_open +import yaml +from kubernetes.client.rest import ApiException + +from sagemaker.hyperpod.space.hyperpod_space_template import HPSpaceTemplate + + +class TestHPSpaceTemplate(unittest.TestCase): + """Test cases for HPSpaceTemplate PySDK""" + + def setUp(self): + """Setup test fixtures""" + self.mock_config_data = { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "WorkspaceTemplate", + "metadata": { + "name": "test-template" + }, + "spec": { + "displayName": "Test Template", + "description": "Test space template" + } + } + self.yaml_content = yaml.dump(self.mock_config_data) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + def test_init_success(self, mock_yaml_load, mock_file): + """Test successful initialization""" + mock_yaml_load.return_value = self.mock_config_data + mock_file.return_value.read.return_value = self.yaml_content + + template = HPSpaceTemplate("test.yaml") + + self.assertEqual(template.config_data, self.mock_config_data) + self.assertEqual(template.name, "test-template") + mock_file.assert_called_once_with("test.yaml", 'r') + + @patch('builtins.open', side_effect=FileNotFoundError) + def test_init_file_not_found(self, mock_file): + """Test initialization with non-existent file""" + with self.assertRaises(FileNotFoundError) as context: + HPSpaceTemplate("nonexistent.yaml") + self.assertIn("File 'nonexistent.yaml' not found", str(context.exception)) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load', side_effect=yaml.YAMLError("Invalid YAML")) + def test_init_yaml_error(self, mock_yaml_load, mock_file): + """Test initialization with invalid YAML""" + with self.assertRaises(ValueError) as context: + HPSpaceTemplate("invalid.yaml") + self.assertIn("Error parsing YAML file", str(context.exception)) + + 
@patch('sagemaker.hyperpod.space.hyperpod_space_template.config.load_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.verify_kubernetes_version_compatibility') + def test_verify_kube_config_success(self, mock_verify_k8s, mock_load_config): + """Test successful kubeconfig verification""" + HPSpaceTemplate.is_kubeconfig_loaded = False + HPSpaceTemplate.verify_kube_config() + + mock_load_config.assert_called_once() + mock_verify_k8s.assert_called_once() + self.assertTrue(HPSpaceTemplate.is_kubeconfig_loaded) + + def test_verify_kube_config_already_loaded(self): + """Test kubeconfig verification when already loaded""" + HPSpaceTemplate.is_kubeconfig_loaded = True + + with patch('sagemaker.hyperpod.space.hyperpod_space_template.config.load_kube_config') as mock_load_config: + HPSpaceTemplate.verify_kube_config() + mock_load_config.assert_not_called() + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_create_success(self, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test successful space template creation""" + mock_yaml_load.return_value = self.mock_config_data + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.create_cluster_custom_object.return_value = self.mock_config_data + + template = HPSpaceTemplate("test.yaml") + template.create() + + mock_verify_config.assert_called_once() + mock_custom_api.create_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + body=self.mock_config_data + ) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + 
@patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') + def test_create_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test space template creation with API exception""" + mock_yaml_load.return_value = self.mock_config_data + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.create_cluster_custom_object.side_effect = ApiException(status=409) + + template = HPSpaceTemplate("test.yaml") + template.create() + + mock_handle_exception.assert_called_once() + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_create_general_exception(self, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test space template creation with general exception""" + mock_yaml_load.return_value = self.mock_config_data + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.create_cluster_custom_object.side_effect = Exception("Creation failed") + + template = HPSpaceTemplate("test.yaml") + + with self.assertRaises(Exception): + template.create() + + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_list_success(self, mock_verify_config, mock_custom_api_class): + """Test successful space template listing""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = { + "items": [ + { + "metadata": {"name": "template1"}, + "spec": {"displayName": "Template 1"} + }, + { + "metadata": {"name": "template2"}, + "spec": {"displayName": "Template 2"} + } + ] + } + mock_custom_api.list_cluster_custom_object.return_value = mock_response + + with patch('builtins.open', 
new_callable=mock_open), \ + patch('yaml.safe_load', return_value=mock_response["items"][0]): + result = HPSpaceTemplate.list() + + self.assertEqual(len(result), 2) + mock_custom_api.list_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') + def test_list_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test space template listing with API exception""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.list_cluster_custom_object.side_effect = ApiException(status=500) + + HPSpaceTemplate.list() + + mock_handle_exception.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_list_general_exception(self, mock_verify_config, mock_custom_api_class): + """Test space template listing with general exception""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.list_cluster_custom_object.side_effect = Exception("List failed") + + with self.assertRaises(Exception): + HPSpaceTemplate.list() + + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_get_success(self, mock_verify_config, mock_custom_api_class): + """Test successful space template retrieval""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = { + "metadata": { + "name": "test-template", + "managedFields": [{"manager": "test"}] + }, + "spec": {"displayName": "Test Template"} + } + expected_response = { + "metadata": 
{"name": "test-template"}, + "spec": {"displayName": "Test Template"} + } + mock_custom_api.get_cluster_custom_object.return_value = mock_response + + with patch('builtins.open', new_callable=mock_open), \ + patch('yaml.safe_load', return_value=expected_response): + result = HPSpaceTemplate.get("test-template") + + mock_custom_api.get_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template" + ) + + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') + def test_get_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test space template retrieval with API exception""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.get_cluster_custom_object.side_effect = ApiException(status=404) + + HPSpaceTemplate.get("nonexistent-template") + + mock_handle_exception.assert_called_once() + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_delete_success(self, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test successful space template deletion""" + mock_yaml_load.return_value = self.mock_config_data + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + template = HPSpaceTemplate("test.yaml") + template.delete() + + mock_verify_config.assert_called_once() + mock_custom_api.delete_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template" + ) + + @patch('builtins.open', new_callable=mock_open) + 
@patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') + def test_delete_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test space template deletion with API exception""" + mock_yaml_load.return_value = self.mock_config_data + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.delete_cluster_custom_object.side_effect = ApiException(status=404) + + template = HPSpaceTemplate("test.yaml") + template.delete() + + mock_handle_exception.assert_called_once() + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_update_success(self, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test successful space template update""" + mock_yaml_load.side_effect = [self.mock_config_data, self.mock_config_data] + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.patch_cluster_custom_object.return_value = self.mock_config_data + + template = HPSpaceTemplate("test.yaml") + template.update("updated.yaml") + + mock_verify_config.assert_called_once() + mock_custom_api.patch_cluster_custom_object.assert_called_once_with( + group="workspace.jupyter.org", + version="v1alpha1", + plural="workspacetemplates", + name="test-template", + body=self.mock_config_data + ) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_update_name_mismatch(self, mock_verify_config, mock_yaml_load, mock_file): + """Test space template update with name mismatch""" + 
mock_yaml_load.side_effect = [ + self.mock_config_data, + {"metadata": {"name": "different-name"}} + ] + + template = HPSpaceTemplate("test.yaml") + + with self.assertRaises(ValueError) as context: + template.update("different.yaml") + self.assertIn("Name mismatch", str(context.exception)) + + @patch('builtins.open') + @patch('yaml.safe_load') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_update_file_not_found(self, mock_verify_config, mock_yaml_load, mock_file): + """Test space template update with non-existent file""" + mock_yaml_load.return_value = self.mock_config_data + mock_file.side_effect = [mock_open().return_value, FileNotFoundError("File 'nonexistent.yaml' not found")] + + template = HPSpaceTemplate("test.yaml") + + with self.assertRaises(FileNotFoundError) as context: + template.update("nonexistent.yaml") + self.assertIn("File 'nonexistent.yaml' not found", str(context.exception)) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + def test_update_yaml_error(self, mock_verify_config, mock_yaml_load, mock_file): + """Test space template update with YAML error""" + mock_yaml_load.side_effect = [self.mock_config_data, yaml.YAMLError("Invalid YAML")] + + template = HPSpaceTemplate("test.yaml") + + with self.assertRaises(ValueError) as context: + template.update("invalid.yaml") + self.assertIn("Error parsing YAML file", str(context.exception)) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') + @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') + def test_update_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class, mock_yaml_load, mock_file): + """Test space template update with API exception""" + mock_yaml_load.side_effect = 
[self.mock_config_data, self.mock_config_data] + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.patch_cluster_custom_object.side_effect = ApiException(status=404) + + template = HPSpaceTemplate("test.yaml") + template.update("updated.yaml") + + mock_handle_exception.assert_called_once() + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + def test_to_yaml(self, mock_yaml_load, mock_file): + """Test converting space template to YAML""" + mock_yaml_load.return_value = self.mock_config_data + + template = HPSpaceTemplate("test.yaml") + result = template.to_yaml() + + self.assertIsInstance(result, str) + self.assertIn("test-template", result) + + @patch('builtins.open', new_callable=mock_open) + @patch('yaml.safe_load') + def test_to_dict(self, mock_yaml_load, mock_file): + """Test converting space template to dictionary""" + mock_yaml_load.return_value = self.mock_config_data + + template = HPSpaceTemplate("test.yaml") + result = template.to_dict() + + self.assertEqual(result, self.mock_config_data) From 8fc72770c742a3a79550e475bbb7e0ea762ca262 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Wed, 5 Nov 2025 16:22:51 -0800 Subject: [PATCH 15/31] Refactor Space CLI using the Space PySDK (#281) * Implement CRUD operations for Space PySDK * Update Space PySDK per new schema * Refactor CLI to use the PySDK --- src/sagemaker/hyperpod/cli/commands/space.py | 138 ++++---- src/sagemaker/hyperpod/cli/space_utils.py | 12 +- .../hyperpod/space/hyperpod_space.py | 77 ++--- test/unit_tests/cli/test_space.py | 302 ++++++++---------- test/unit_tests/cli/test_space_utils.py | 166 +++++----- test/unit_tests/test_hyperpod_space.py | 168 +--------- 6 files changed, 341 insertions(+), 522 deletions(-) diff --git a/src/sagemaker/hyperpod/cli/commands/space.py b/src/sagemaker/hyperpod/cli/commands/space.py index 450fcaac..9fd9ba31 100644 --- a/src/sagemaker/hyperpod/cli/commands/space.py +++ 
b/src/sagemaker/hyperpod/cli/commands/space.py @@ -1,9 +1,11 @@ import click import json +import yaml from tabulate import tabulate -from sagemaker.hyperpod.cli.clients.kubernetes_client import KubernetesClient +from sagemaker.hyperpod.space.hyperpod_space import HPSpace from sagemaker.hyperpod.cli.space_utils import generate_click_command from hyperpod_space_template.registry import SCHEMA_REGISTRY +from hyperpod_space_template.v1_0.model import SpaceConfig from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( _hyperpod_telemetry_emitter, ) @@ -17,16 +19,12 @@ ) def space_create(version, config): """Create a space resource.""" - try: - name = config.get("name") - namespace = config.get("namespace") - space_spec = config.get("space_spec") - - k8s_client = KubernetesClient() - k8s_client.create_space(namespace, space_spec) + space_config = SpaceConfig(**config) + space = HPSpace(config=space_config) + space.create() - click.echo(f"Space '{name}' created successfully in namespace '{namespace}'") + click.echo(f"Space '{space_config.name}' created successfully in namespace '{space_config.namespace}'") except Exception as e: click.echo(f"Error creating space: {e}", err=True) @@ -36,24 +34,38 @@ def space_create(version, config): @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") def space_list(namespace, output): """List space resources.""" - k8s_client = KubernetesClient() - try: - resources = k8s_client.list_spaces(namespace) + spaces = HPSpace.list(namespace=namespace) if output == "json": - click.echo(json.dumps(resources, indent=2)) + spaces_data = [] + for space in spaces: + space_dict = space.config.model_dump() + spaces_data.append(space_dict) + click.echo(json.dumps(spaces_data, indent=2)) else: - items = resources.get("items", []) - if items: + if spaces: table_data = [] - for item in items: + for space in spaces: + # Extract status conditions from raw resource + available = "" + progressing = "" + degraded = 
"" + + if space.status and 'conditions' in space.status: + conditions = {c['type']: c['status'] for c in space.status['conditions']} + available = conditions.get('Available', '') + progressing = conditions.get('Progressing', '') + degraded = conditions.get('Degraded', '') + table_data.append([ - item["metadata"]["name"], - item["metadata"]["namespace"], - item.get("status", {}).get("phase", "Unknown") + space.config.name, + namespace, + available, + progressing, + degraded ]) - click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "STATUS"])) + click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "AVAILABLE", "PROGRESSING", "DEGRADED"])) else: click.echo("No spaces found") except Exception as e: @@ -66,17 +78,16 @@ def space_list(namespace, output): @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") def space_describe(name, namespace, output): """Describe a space resource.""" - k8s_client = KubernetesClient() - try: - resource = k8s_client.get_space(namespace, name) - resource["metadata"].pop('managedFields', None) + current_space = HPSpace.get(name=name, namespace=namespace) + + # Combine config and raw resource data + current_space.raw_resource.get('metadata', {}).pop('managedFields', None) if output == "json": - click.echo(json.dumps(resource, indent=2)) + click.echo(json.dumps(current_space.raw_resource, indent=2)) else: - import yaml - click.echo(yaml.dump(resource, default_flow_style=False)) + click.echo(yaml.dump(current_space.raw_resource, default_flow_style=False)) except Exception as e: click.echo(f"Error describing space '{name}': {e}", err=True) @@ -86,10 +97,9 @@ def space_describe(name, namespace, output): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_delete(name, namespace): """Delete a space resource.""" - k8s_client = KubernetesClient() - try: - k8s_client.delete_space(namespace, name) + current_space = HPSpace.get(name=name, 
namespace=namespace) + current_space.delete() click.echo(f"Space '{name}' deleted successfully") except Exception as e: @@ -104,22 +114,16 @@ def space_delete(name, namespace): ) def space_update(version, config): """Update a space resource.""" - k8s_client = KubernetesClient() - try: - name = config["name"] - namespace = config["namespace"] - space_spec = config.get("space_spec", {}) + current_space = HPSpace.get(name=config['name'], namespace=config['namespace']) + if not config.get("display_name"): + config["display_name"] = current_space.config.display_name - k8s_client.patch_space( - namespace=namespace, - name=name, - body=space_spec - ) + current_space.update(**config) - click.echo(f"Space '{name}' updated successfully") + click.echo(f"Space '{current_space.config.name}' updated successfully") except Exception as e: - click.echo(f"Error updating space '{name}': {e}", err=True) + click.echo(f"Error updating space: {e}", err=True) @click.command("hyp-space") @@ -127,16 +131,9 @@ def space_update(version, config): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_start(name, namespace): """Start a space resource.""" - k8s_client = KubernetesClient() - try: - # Patch the resource to set desired status to "Running" - patch_body = {"spec": {"desiredStatus": "Running"}} - k8s_client.patch_space( - namespace=namespace, - name=name, - body=patch_body - ) + current_space = HPSpace.get(name=name, namespace=namespace) + current_space.start() click.echo(f"Space '{name}' start requested") except Exception as e: @@ -148,16 +145,9 @@ def space_start(name, namespace): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_stop(name, namespace): """Stop a space resource.""" - k8s_client = KubernetesClient() - try: - # Patch the resource to set desired status to "Stopped" - patch_body = {"spec": {"desiredStatus": "Stopped"}} - k8s_client.patch_space( - 
namespace=namespace, - name=name, - body=patch_body - ) + current_space = HPSpace.get(name=name, namespace=namespace) + current_space.stop() click.echo(f"Space '{name}' stop requested") except Exception as e: @@ -167,31 +157,13 @@ def space_stop(name, namespace): @click.command("hyp-space") @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -def space_get_logs(name, namespace): +@click.option("--pod-name", required=False, help="Name of the pod to get logs from") +@click.option("--container", required=False, help="Name of the container to get logs from") +def space_get_logs(name, namespace, pod_name, container): """Get logs for a space resource.""" - k8s_client = KubernetesClient() - try: - # Get pods associated with the space - pods = k8s_client.list_pods_with_labels( - namespace=namespace, - label_selector=f"sagemaker.aws.com/space-name={name}" - ) - - if not pods.items: - click.echo(f"No pods found for space '{name}'") - return - - # Get logs from the first pod - pod_name = pods.items[0].metadata.name - logs = k8s_client.get_logs_for_pod( - pod_name=pod_name, - namespace=namespace, - ) - + current_space = HPSpace.get(name=name, namespace=namespace) + logs = current_space.get_logs(pod_name=pod_name, container=container) click.echo(logs) except Exception as e: click.echo(f"Error getting logs for space '{name}': {e}", err=True) - - - diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index 5cd2948d..f9a27a6e 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -181,9 +181,17 @@ def wrapped_func(*args, **kwargs): filtered_kwargs[key] = value + # For update operations, add temporary display_name if not provided to pass validation + is_update_and_display_name_not_exist = False + if is_update and 'display_name' not in filtered_kwargs: + 
filtered_kwargs['display_name'] = 'dummy' + is_update_and_display_name_not_exist = True + try: flat = Model(**filtered_kwargs) - domain_config = flat.to_domain() + config_dict = flat.model_dump(exclude_none=True, by_alias=True) + if is_update_and_display_name_not_exist: + config_dict['display_name'] = None except ValidationError as e: error_messages = [] for err in e.errors(): @@ -195,7 +203,7 @@ def wrapped_func(*args, **kwargs): f"Configuration validation errors:\n" + "\n".join(error_messages) ) - return func(version, domain_config) + return func(version, config_dict) # 2) inject click options from JSON Schema wrapped_func = click.option( diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index 5982e1b7..fe3f3c5f 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -1,6 +1,6 @@ import logging import yaml -from typing import List, Optional, ClassVar, Dict +from typing import List, Optional, ClassVar, Dict, Any from pydantic import BaseModel, Field, ConfigDict from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -38,12 +38,37 @@ class HPSpace(BaseModel): config: SpaceConfig = Field( description="The space configuration using the template model" ) + + raw_resource: Optional[Dict[str, Any]] = Field( + default=None, + description="The complete Kubernetes resource data including apiVersion, kind, metadata, and status" + ) @classmethod def get_logger(cls): """Get logger for the class.""" return logging.getLogger(__name__) + @property + def api_version(self) -> Optional[str]: + """Get the apiVersion from the Kubernetes resource.""" + return self.raw_resource.get("apiVersion") if self.raw_resource else None + + @property + def kind(self) -> Optional[str]: + """Get the kind from the Kubernetes resource.""" + return self.raw_resource.get("kind") if self.raw_resource else None + + @property + def metadata(self) -> 
Optional[Dict[str, Any]]: + """Get the metadata from the Kubernetes resource.""" + return self.raw_resource.get("metadata") if self.raw_resource else None + + @property + def status(self) -> Optional[Dict[str, Any]]: + """Get the status from the Kubernetes resource.""" + return self.raw_resource.get("status") if self.raw_resource else None + @classmethod def verify_kube_config(cls): """Verify and load Kubernetes configuration.""" @@ -55,24 +80,6 @@ def verify_kube_config(cls): except Exception as e: raise RuntimeError(f"Failed to load kubeconfig: {e}") - def space_exists(self): - """Check if the space already exists""" - custom_api = client.CustomObjectsApi() - try: - custom_api.get_namespaced_custom_object( - group=SPACE_GROUP, - version=SPACE_VERSION, - namespace=self.config.namespace, - plural=SPACE_PLURAL, - name=self.config.name - ) - return True - except ApiException as e: - if e.status == 404: - return False - # re-raise if exception is not 404 (Not found) - raise - @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space") def create(self, debug: bool = False): """Create and submit the HyperPod Space to the Kubernetes cluster. 
@@ -88,10 +95,6 @@ def create(self, debug: bool = False): logger = self.get_logger() logger = setup_logging(logger, debug) - if self.space_exists(): - logger.info(f"HyperPod Space '{self.config.name}' already exists in namespace '{self.config.namespace}'") - return - # Convert config to domain model domain_config = self.config.to_domain() config_body = domain_config["space_spec"] @@ -160,6 +163,7 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: space = cls( config=space_config, + raw_resource=item ) spaces.append(space) @@ -202,6 +206,7 @@ def get(cls, name: str, namespace: str = "default") -> "HPSpace": return cls( config=space_config, + raw_resource=response ) except Exception as e: handle_exception(e, name, namespace) @@ -216,10 +221,6 @@ def delete(self): self.verify_kube_config() logger = self.get_logger() - if not self.space_exists(): - logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") - return - custom_api = client.CustomObjectsApi() try: @@ -248,16 +249,12 @@ def update(self, **kwargs): self.verify_kube_config() logger = self.get_logger() - if not self.space_exists(): - logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") - return - custom_api = client.CustomObjectsApi() - # Update the local config - for key, value in kwargs.items(): - if hasattr(self.config, key): - setattr(self.config, key, value) + # Update space config with the input config + current_config = self.config.model_dump(by_alias=True) + current_config.update(kwargs) + self.config = SpaceConfig(**current_config) # Convert to domain model and extract spec domain_config = self.config.to_domain() @@ -295,17 +292,13 @@ def list_pods(self) -> List[str]: """ self.verify_kube_config() logger = self.get_logger() - - if not self.space_exists(): - logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") - return [] v1 = 
client.CoreV1Api() try: pods = v1.list_namespaced_pod( namespace=self.config.namespace, - label_selector=f"sagemaker.aws.com/space-name={self.config.name}" + label_selector=f"{SPACE_GROUP}/workspaceName={self.config.name}" ) return [pod.metadata.name for pod in pods.items] except Exception as e: @@ -324,10 +317,6 @@ def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = No """ self.verify_kube_config() logger = self.get_logger() - - if not self.space_exists(): - logger.info(f"HyperPod Space '{self.config.name}' does not exist in namespace '{self.config.namespace}'") - return "" if not pod_name: pods = self.list_pods() diff --git a/test/unit_tests/cli/test_space.py b/test/unit_tests/cli/test_space.py index 07851f9b..2341c0a4 100644 --- a/test/unit_tests/cli/test_space.py +++ b/test/unit_tests/cli/test_space.py @@ -20,11 +20,11 @@ class TestSpaceCommands: def setup_method(self): self.runner = CliRunner() - self.mock_k8s_client = Mock() + self.mock_hp_space = Mock() + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): + def test_space_create_success(self, mock_load_schema, mock_hp_space_class): """Test successful space creation""" # Mock schema loading mock_load_schema.return_value = { @@ -36,6 +36,10 @@ def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): "required": ["name", "display_name"] } + # Mock HPSpace instance + mock_hp_space_instance = Mock() + mock_hp_space_class.return_value = mock_hp_space_instance + # Mock model registry mock_model = Mock() mock_model.return_value = Mock() @@ -46,21 +50,21 @@ def test_space_create_success(self, mock_k8s_client_class, mock_load_schema): "space_spec": {"spec": {"image": "test-image"}} } - # Mock KubernetesClient - mock_k8s_instance = Mock() - 
mock_k8s_client_class.return_value = mock_k8s_instance - with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - result = self.runner.invoke(space_create, [ - '--version', '1.0', - '--name', 'test-space', - '--display-name', 'Test Space', - '--namespace', 'test-ns' - ]) + with patch('sagemaker.hyperpod.cli.commands.space.SpaceConfig') as mock_space_config: + mock_space_config.return_value.name = "test-space" + mock_space_config.return_value.namespace = "test-ns" + + result = self.runner.invoke(space_create, [ + '--version', '1.0', + '--name', 'test-space', + '--display-name', 'Test Space', + '--namespace', 'test-ns' + ]) assert result.exit_code == 0 assert "Space 'test-space' created successfully" in result.output - mock_k8s_instance.create_space.assert_called_once() + mock_hp_space_instance.create.assert_called_once() @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_space_create_missing_required_args(self, mock_load_schema): @@ -74,12 +78,12 @@ def test_space_create_missing_required_args(self, mock_load_schema): assert result.exit_code != 0 assert 'Missing option' in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_create_k8s_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_create_hp_space_error(self, mock_hp_space_class): """Test space creation error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.create_space.side_effect = Exception("Creation failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.create.side_effect = Exception("Creation failed") + mock_hp_space_class.return_value = mock_hp_space_instance mock_model = Mock() mock_model.return_value = Mock() @@ -110,23 +114,27 @@ def test_space_create_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error creating space: Creation failed" 
in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_list_table_output(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_list_table_output(self, mock_hp_space_class): """Test space list with table output""" - mock_k8s_instance = Mock() - mock_k8s_instance.list_spaces.return_value = { - "items": [ - { - "metadata": {"name": "space1", "namespace": "ns1"}, - "status": {"phase": "Running"} - }, - { - "metadata": {"name": "space2", "namespace": "ns2"}, - "status": {"phase": "Stopped"} - } - ] - } - mock_k8s_client_class.return_value = mock_k8s_instance + # Mock HPSpace instances with config and status + mock_space1 = Mock() + mock_space1.config.name = "space1" + mock_space1.status = {"conditions": [ + {"type": "Available", "status": "True"}, + {"type": "Progressing", "status": "False"}, + {"type": "Degraded", "status": "False"} + ]} + + mock_space2 = Mock() + mock_space2.config.name = "space2" + mock_space2.status = {"conditions": [ + {"type": "Available", "status": "False"}, + {"type": "Progressing", "status": "True"}, + {"type": "Degraded", "status": "False"} + ]} + + mock_hp_space_class.list.return_value = [mock_space1, mock_space2] result = self.runner.invoke(space_list, [ '--namespace', 'test-ns', @@ -136,19 +144,16 @@ def test_space_list_table_output(self, mock_k8s_client_class): assert result.exit_code == 0 assert "space1" in result.output assert "space2" in result.output - mock_k8s_instance.list_spaces.assert_called_once_with('test-ns') + mock_hp_space_class.list.assert_called_once_with(namespace='test-ns') - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_list_json_output(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_list_json_output(self, mock_hp_space_class): """Test space list with JSON output""" - mock_resources = { - "items": [ - {"metadata": {"name": "space1", 
"namespace": "ns1"}} - ] - } - mock_k8s_instance = Mock() - mock_k8s_instance.list_spaces.return_value = mock_resources - mock_k8s_client_class.return_value = mock_k8s_instance + # Mock HPSpace instances + mock_space1 = Mock() + mock_space1.config.model_dump.return_value = {"name": "space1", "namespace": "ns1"} + + mock_hp_space_class.list.return_value = [mock_space1] result = self.runner.invoke(space_list, [ '--namespace', 'test-ns', @@ -157,14 +162,12 @@ def test_space_list_json_output(self, mock_k8s_client_class): assert result.exit_code == 0 output_json = json.loads(result.output) - assert output_json == mock_resources + assert output_json == [{"name": "space1", "namespace": "ns1"}] - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_list_empty(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_list_empty(self, mock_hp_space_class): """Test space list with no items""" - mock_k8s_instance = Mock() - mock_k8s_instance.list_spaces.return_value = {"items": []} - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_class.list.return_value = [] result = self.runner.invoke(space_list, [ '--namespace', 'test-ns' @@ -173,12 +176,10 @@ def test_space_list_empty(self, mock_k8s_client_class): assert result.exit_code == 0 assert "No spaces found" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_list_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_list_error(self, mock_hp_space_class): """Test space list error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.list_spaces.side_effect = Exception("List failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_class.list.side_effect = Exception("List failed") result = self.runner.invoke(space_list, [ '--namespace', 'test-ns' @@ -187,13 +188,13 @@ def test_space_list_error(self, 
mock_k8s_client_class): assert result.exit_code == 0 assert "Error listing spaces: List failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_describe_yaml_output(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_describe_yaml_output(self, mock_hp_space_class): """Test space describe with YAML output""" mock_resource = {"metadata": {"name": "test-space"}} - mock_k8s_instance = Mock() - mock_k8s_instance.get_space.return_value = mock_resource - mock_k8s_client_class.return_value = mock_k8s_instance + # mock_hp_space_instance = Mock() + # mock_hp_space_instance.raw_resource = mock_resource + # mock_hp_space_class.get.return_value = mock_hp_space_instance with patch('yaml.dump') as mock_yaml_dump: mock_yaml_dump.return_value = "yaml_output" @@ -204,15 +205,15 @@ def test_space_describe_yaml_output(self, mock_k8s_client_class): assert result.exit_code == 0 assert "yaml_output" in result.output - mock_k8s_instance.get_space.assert_called_once_with('test-ns', 'test-space') + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_describe_json_output(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_describe_json_output(self, mock_hp_space_class): """Test space describe with JSON output""" mock_resource = {"metadata": {"name": "test-space"}} - mock_k8s_instance = Mock() - mock_k8s_instance.get_space.return_value = mock_resource - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.raw_resource = mock_resource + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_describe, [ '--name', 'test-space', @@ -224,12 +225,10 @@ def test_space_describe_json_output(self, mock_k8s_client_class): 
output_json = json.loads(result.output) assert output_json == mock_resource - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_describe_k8s_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_describe_hp_space_error(self, mock_hp_space_class): """Test space describe error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.get_space.side_effect = Exception("Describe failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_class.get.side_effect = Exception("Describe failed") result = self.runner.invoke(space_describe, [ '--name', 'test-space', @@ -239,11 +238,11 @@ def test_space_describe_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error describing space 'test-space': Describe failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_delete_success(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_delete_success(self, mock_hp_space_class): """Test successful space deletion""" - mock_k8s_instance = Mock() - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_delete, [ '--name', 'test-space', @@ -252,14 +251,13 @@ def test_space_delete_success(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Space 'test-space' deleted successfully" in result.output - mock_k8s_instance.delete_space.assert_called_once_with('test-ns', 'test-space') + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.delete.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_delete_k8s_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + 
def test_space_delete_hp_space_error(self, mock_hp_space_class): """Test space delete error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.delete_space.side_effect = Exception("Delete failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_class.get.side_effect = Exception("Delete failed") result = self.runner.invoke(space_delete, [ '--name', 'test-space', @@ -269,9 +267,9 @@ def test_space_delete_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error deleting space 'test-space': Delete failed" in result.output + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): + def test_space_update_success(self, mock_load_schema, mock_hp_space_class): """Test successful space update""" # Mock schema loading mock_load_schema.return_value = { @@ -283,6 +281,12 @@ def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): "required": ["name"] } + # Mock HPSpace instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.config.name = "test-space" + mock_hp_space_instance.config.display_name = "Test Space" + mock_hp_space_class.get.return_value = mock_hp_space_instance + # Mock model registry mock_model = Mock() mock_model.return_value = Mock() @@ -292,10 +296,6 @@ def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): "space_spec": {"spec": {"image": "updated-image"}} } - # Mock KubernetesClient - mock_k8s_instance = Mock() - mock_k8s_client_class.return_value = mock_k8s_instance - with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): result = self.runner.invoke(space_update, [ '--version', '1.0', @@ -306,14 +306,15 @@ def test_space_update_success(self, mock_k8s_client_class, mock_load_schema): assert 
result.exit_code == 0 assert "Space 'test-space' updated successfully" in result.output - mock_k8s_instance.patch_space.assert_called_once() + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.update.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_update_k8_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_update_hp_space_error(self, mock_hp_space_class): """Test space update error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.patch_space.side_effect = Exception("Update failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.update.side_effect = Exception("Update failed") + mock_hp_space_class.get.return_value = mock_hp_space_instance mock_model = Mock() mock_model.return_value = Mock() @@ -341,13 +342,13 @@ def test_space_update_k8_error(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "Error updating space 'test-space': Update failed" in result.output + assert "Error updating space: Update failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_start_success(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_start_success(self, mock_hp_space_class): """Test successful space start""" - mock_k8s_instance = Mock() - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_start, [ '--name', 'test-space', @@ -356,18 +357,15 @@ def test_space_start_success(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Space 'test-space' start requested" in result.output - mock_k8s_instance.patch_space.assert_called_once_with( - namespace='test-ns', - 
name='test-space', - body={"spec": {"desiredStatus": "Running"}} - ) - - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_start_k8s_error(self, mock_k8s_client_class): + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.start.assert_called_once() + + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_start_hp_space_error(self, mock_hp_space_class): """Test space start error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.patch_space.side_effect = Exception("Start failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.start.side_effect = Exception("Start failed") + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_start, [ '--name', 'test-space', @@ -377,11 +375,11 @@ def test_space_start_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error starting space 'test-space': Start failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_stop_success(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_stop_success(self, mock_hp_space_class): """Test successful space stop""" - mock_k8s_instance = Mock() - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_stop, [ '--name', 'test-space', @@ -390,18 +388,15 @@ def test_space_stop_success(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Space 'test-space' stop requested" in result.output - mock_k8s_instance.patch_space.assert_called_once_with( - namespace='test-ns', - name='test-space', - body={"spec": {"desiredStatus": "Stopped"}} - ) - - 
@patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_stop_k8s_error(self, mock_k8s_client_class): + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.stop.assert_called_once() + + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_stop_hp_space_error(self, mock_hp_space_class): """Test space stop error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.patch_space.side_effect = Exception("Stop failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.stop.side_effect = Exception("Stop failed") + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_stop, [ '--name', 'test-space', @@ -411,18 +406,12 @@ def test_space_stop_k8s_error(self, mock_k8s_client_class): assert result.exit_code == 0 assert "Error stopping space 'test-space': Stop failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_get_logs_success(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_get_logs_success(self, mock_hp_space_class): """Test successful space get logs""" - mock_pod = Mock() - mock_pod.metadata.name = "test-pod" - mock_pods = Mock() - mock_pods.items = [mock_pod] - - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.return_value = mock_pods - mock_k8s_instance.get_logs_for_pod.return_value = "test logs" - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.get_logs.return_value = "test logs" + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', @@ -431,24 +420,15 @@ def test_space_get_logs_success(self, mock_k8s_client_class): assert result.exit_code == 0 assert "test logs" in result.output 
- mock_k8s_instance.list_pods_with_labels.assert_called_once_with( - namespace='test-ns', - label_selector='sagemaker.aws.com/space-name=test-space' - ) - mock_k8s_instance.get_logs_for_pod.assert_called_once_with( - pod_name='test-pod', - namespace='test-ns' - ) - - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_get_logs_no_pods(self, mock_k8s_client_class): - """Test space get logs with no pods""" - mock_pods = Mock() - mock_pods.items = [] + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.get_logs.assert_called_once_with(pod_name=None, container=None) - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.return_value = mock_pods - mock_k8s_client_class.return_value = mock_k8s_instance + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_get_logs_no_pods(self, mock_hp_space_class): + """Test space get logs with no pods""" + mock_hp_space_instance = Mock() + mock_hp_space_instance.get_logs.return_value = "" + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', @@ -456,14 +436,14 @@ def test_space_get_logs_no_pods(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "No pods found for space 'test-space'" in result.output + # HPSpace.get_logs() handles the "no pods" case internally - @patch('sagemaker.hyperpod.cli.commands.space.KubernetesClient') - def test_space_get_logs_k8s_error(self, mock_k8s_client_class): + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_get_logs_hp_space_error(self, mock_hp_space_class): """Test space get logs error handling""" - mock_k8s_instance = Mock() - mock_k8s_instance.list_pods_with_labels.side_effect = Exception("List pod failed") - mock_k8s_client_class.return_value = mock_k8s_instance + mock_hp_space_instance = Mock() + mock_hp_space_instance.get_logs.side_effect = 
Exception("Get logs failed") + mock_hp_space_class.get.return_value = mock_hp_space_instance result = self.runner.invoke(space_get_logs, [ '--name', 'test-space', @@ -471,7 +451,7 @@ def test_space_get_logs_k8s_error(self, mock_k8s_client_class): ]) assert result.exit_code == 0 - assert "Error getting logs for space 'test-space': List pod failed" in result.output + assert "Error getting logs for space 'test-space': Get logs failed" in result.output def test_missing_required_arguments(self): """Test commands with missing required arguments""" diff --git a/test/unit_tests/cli/test_space_utils.py b/test/unit_tests/cli/test_space_utils.py index b8529f14..389949f5 100644 --- a/test/unit_tests/cli/test_space_utils.py +++ b/test/unit_tests/cli/test_space_utils.py @@ -3,7 +3,7 @@ import click from click.testing import CliRunner from unittest.mock import Mock, patch -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from sagemaker.hyperpod.cli.space_utils import load_schema_for_version, generate_click_command @@ -70,11 +70,9 @@ def test_version_handling(self, mock_load_schema): schema = {'properties': {}, 'required': []} mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - pass - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'2.0': DummyModel} @@ -108,18 +106,16 @@ def test_resources_building(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.resources = kwargs.get('resources') - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(json.dumps(domain_config.resources)) + click.echo(json.dumps(domain_config.get('resources'))) # Test with custom 
CPU and memory result = self.runner.invoke(cmd, ['--cpu', '1000m', '--memory', '1Gi']) @@ -155,11 +151,9 @@ def test_type_conversion(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @@ -167,10 +161,10 @@ def to_domain(self): @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): click.echo(json.dumps({ - 'name': domain_config.name, - 'desired_status': getattr(domain_config, 'desired_status', None), - 'storage_size': getattr(domain_config, 'storage_size', None), - 'port': getattr(domain_config, 'port', None) + 'name': domain_config.get('name'), + 'desired_status': domain_config.get('desired_status'), + 'storage_size': domain_config.get('storage_size'), + 'port': domain_config.get('port') })) # Test string and enum types @@ -205,18 +199,16 @@ def test_successful_command_execution(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(f'success: {domain_config.name}') + click.echo(f'success: {domain_config.get("name")}') # Test successful execution result = self.runner.invoke(cmd, ['--name', 'test-space']) @@ -237,11 +229,9 @@ def test_immutable_fields_excluded_in_update(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' 
registry = {'1.0': DummyModel} @@ -277,12 +267,9 @@ def test_filtered_kwargs(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.received_kwargs = kwargs - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @@ -290,7 +277,7 @@ def to_domain(self): @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): # Check that None values were filtered out - click.echo(json.dumps(domain_config.received_kwargs)) + click.echo(json.dumps(domain_config)) result = self.runner.invoke(cmd, ['--name', 'test-space']) assert result.exit_code == 0 @@ -305,9 +292,9 @@ def test_default_version_injection(self, mock_load_schema): schema = {'properties': {}, 'required': []} mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): pass - def to_domain(self): return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel, '2.0': DummyModel} @@ -323,7 +310,6 @@ def cmd(version, domain_config): # Test custom version result = self.runner.invoke(cmd, ['--version', '2.0']) - print(result.output) assert result.exit_code == 0 assert result.output.strip() == '2.0' @@ -340,11 +326,9 @@ def test_schema_defaults_and_required_fields(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @@ -360,7 +344,6 @@ def cmd(version, domain_config): # Test with required field provided result = self.runner.invoke(cmd, ['--name', 'test-space', '--namespace', 'test-ns']) - print(result.output) assert result.exit_code == 0 assert result.output.strip() == 'success' @@ -376,18 +359,16 
@@ def test_volume_parsing(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(json.dumps(getattr(domain_config, 'volumes', None))) + click.echo(json.dumps(domain_config.get('volumes'))) # Test valid volume parsing result = self.runner.invoke(cmd, [ @@ -431,18 +412,16 @@ def test_storage_parsing(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(json.dumps(getattr(domain_config, 'storage', None))) + click.echo(json.dumps(domain_config.get('storage'))) # Test valid storage parsing result = self.runner.invoke(cmd, [ @@ -475,18 +454,16 @@ def test_container_config_parsing_simple(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(json.dumps(getattr(domain_config, 'container_config', None))) + click.echo(json.dumps(domain_config.get('container_config'))) # Test valid container config with semicolon format result = self.runner.invoke(cmd, [ @@ -519,11 +496,9 @@ def 
test_json_object_parsing(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @@ -531,8 +506,8 @@ def to_domain(self): @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): result = { - 'metadata': getattr(domain_config, 'metadata', None), - 'tags': getattr(domain_config, 'tags', None) + 'metadata': domain_config.get('metadata'), + 'tags': domain_config.get('tags') } click.echo(json.dumps(result)) @@ -573,18 +548,16 @@ def test_anyof_type_handling(self, mock_load_schema): } mock_load_schema.return_value = schema - class DummyModel: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - def to_domain(self): - return self + class DummyModel(BaseModel): + class Config: + extra = 'allow' registry = {'1.0': DummyModel} @click.command() @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") def cmd(version, domain_config): - click.echo(json.dumps(getattr(domain_config, 'config', None))) + click.echo(json.dumps(domain_config.get('config'))) # Test with JSON object for anyOf type result = self.runner.invoke(cmd, [ @@ -594,3 +567,36 @@ def cmd(version, domain_config): assert result.exit_code == 0 config = json.loads(result.output) assert config['setting'] == 'value' + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_display_name_optional_in_update_mode(self, mock_load_schema): + """Test that display_name is optional in update mode""" + schema = { + 'properties': { + 'name': {'type': 'string'}, + 'display_name': {'type': 'string'}, + 'image': {'type': 'string'} + }, + 'required': ['name', 'display_name'] + } + mock_load_schema.return_value = schema + + class DummyModel(BaseModel): + class Config: + extra = 'allow' + + 
registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command( + registry=registry, + schema_pkg="hyperpod_space_template", + is_update=True + ) + def cmd(version, domain_config): + click.echo('success') + + # In update mode, display_name should not be required + result = self.runner.invoke(cmd, ['--name', 'test-space']) + assert result.exit_code == 0 + assert 'success' in result.output diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index 23ee7553..60689b01 100644 --- a/test/unit_tests/test_hyperpod_space.py +++ b/test/unit_tests/test_hyperpod_space.py @@ -51,12 +51,10 @@ def test_verify_kube_config_already_loaded(self): @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_create_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + def test_create_success(self, mock_verify_config, mock_custom_api_class): """Test successful dev space creation""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = False # Mock the config.to_domain() method mock_domain_config = { @@ -72,34 +70,15 @@ def test_create_success(self, mock_space_exists, mock_verify_config, mock_custom self.hp_space.create() mock_verify_config.assert_called_once() - mock_space_exists.assert_called_once() mock_custom_api.create_namespaced_custom_object.assert_called_once() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') - @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_create_already_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): - """Test dev space creation when resource already exists""" - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = True - - self.hp_space.create() - - 
mock_verify_config.assert_called_once() - mock_space_exists.assert_called_once() - # Should not call create since resource exists - mock_custom_api.create_namespaced_custom_object.assert_not_called() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - @patch.object(HPSpace, 'space_exists') - def test_create_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + def test_create_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): """Test dev space creation failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = False # Mock creation failure mock_custom_api.create_namespaced_custom_object.side_effect = Exception("Creation failed") @@ -118,37 +97,7 @@ def test_create_failure(self, mock_space_exists, mock_handle_exception, mock_ver mock_handle_exception.assert_called_once() - def test_space_exists_success(self): - """Test space_exists method when space exists""" - with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.get_namespaced_custom_object.return_value = { - "metadata": {"name": "test-space", "namespace": "test-namespace"} - } - - result = self.hp_space.space_exists() - self.assertTrue(result) - - def test_space_exists_not_found(self): - """Test space_exists method when space doesn't exist""" - with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.get_namespaced_custom_object.side_effect = ApiException(status=404) - - result = self.hp_space.space_exists() - self.assertFalse(result) - - def 
test_space_exists_api_error(self): - """Test space_exists method with non-404 API error""" - with patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') as mock_custom_api_class: - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.get_namespaced_custom_object.side_effect = ApiException(status=500) - - with self.assertRaises(ApiException): - self.hp_space.space_exists() + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @@ -252,17 +201,14 @@ def test_get_failure(self, mock_handle_exception, mock_verify_config, mock_custo @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_delete_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + def test_delete_success(self, mock_verify_config, mock_custom_api_class): """Test successful dev space deletion""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = True self.hp_space.delete() mock_verify_config.assert_called_once() - mock_space_exists.assert_called_once() mock_custom_api.delete_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", @@ -271,31 +217,13 @@ def test_delete_success(self, mock_space_exists, mock_verify_config, mock_custom name="test-space" ) - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') - @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_delete_not_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): - """Test dev space deletion when space doesn't exist""" - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = False - - self.hp_space.delete() - - 
mock_verify_config.assert_called_once() - mock_space_exists.assert_called_once() - # Should not call delete since resource doesn't exist - mock_custom_api.delete_namespaced_custom_object.assert_not_called() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - @patch.object(HPSpace, 'space_exists') - def test_delete_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + def test_delete_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): """Test dev space deletion failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = True mock_custom_api.delete_namespaced_custom_object.side_effect = Exception("Delete failed") self.hp_space.delete() @@ -304,12 +232,10 @@ def test_delete_failure(self, mock_space_exists, mock_handle_exception, mock_ver @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_update_success(self, mock_space_exists, mock_verify_config, mock_custom_api_class): + def test_update_success(self, mock_verify_config, mock_custom_api_class): """Test successful dev space update""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = True mock_domain_config = { "space_spec": { @@ -329,31 +255,13 @@ def test_update_success(self, mock_space_exists, mock_verify_config, mock_custom body={"spec": {"desiredStatus": "Stopped"}} ) - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') - @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_update_not_exists(self, mock_space_exists, mock_verify_config, mock_custom_api_class): - """Test dev space update when space 
doesn't exist""" - mock_custom_api = Mock() - mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = False - - self.hp_space.update(desired_status="Stopped") - - mock_verify_config.assert_called_once() - mock_space_exists.assert_called_once() - # Should not call update since resource doesn't exist - mock_custom_api.patch_namespaced_custom_object.assert_not_called() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - @patch.object(HPSpace, 'space_exists') - def test_update_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_custom_api_class): + def test_update_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): """Test dev space update failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_space_exists.return_value = True mock_custom_api.patch_namespaced_custom_object.side_effect = Exception("Update failed") mock_domain_config = {"space_spec": {"spec": {}}} @@ -377,12 +285,10 @@ def test_stop(self, mock_update): @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_list_pods_success(self, mock_space_exists, mock_verify_config, mock_core_api_class): + def test_list_pods_success(self, mock_verify_config, mock_core_api_class): """Test successful pod listing""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = True mock_pod1 = Mock() mock_pod1.metadata.name = "pod1" @@ -398,32 +304,16 @@ def test_list_pods_success(self, mock_space_exists, mock_verify_config, mock_cor self.assertEqual(result, ["pod1", "pod2"]) mock_core_api.list_namespaced_pod.assert_called_once_with( namespace="test-namespace", - 
label_selector="sagemaker.aws.com/space-name=test-space" + label_selector="workspace.jupyter.org/workspaceName=test-space" ) - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') - @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_list_pods_not_exists(self, mock_space_exists, mock_verify_config, mock_core_api_class): - """Test pod listing when space doesn't exist""" - mock_core_api = Mock() - mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = False - - result = self.hp_space.list_pods() - - self.assertEqual(result, []) - mock_core_api.list_namespaced_pod.assert_not_called() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - @patch.object(HPSpace, 'space_exists') - def test_list_pods_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_core_api_class): + def test_list_pods_failure(self, mock_handle_exception, mock_verify_config, mock_core_api_class): """Test pod listing failure""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = True mock_core_api.list_namespaced_pod.side_effect = Exception("List pods failed") self.hp_space.list_pods() @@ -433,12 +323,10 @@ def test_list_pods_failure(self, mock_space_exists, mock_handle_exception, mock_ @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') @patch.object(HPSpace, 'list_pods') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_with_pod_name(self, mock_space_exists, mock_list_pods, mock_verify_config, mock_core_api_class): + def test_get_logs_with_pod_name(self, mock_list_pods, mock_verify_config, mock_core_api_class): """Test getting logs with specific pod name""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - 
mock_space_exists.return_value = True mock_core_api.read_namespaced_pod_log.return_value = "test logs" result = self.hp_space.get_logs(pod_name="test-pod") @@ -453,12 +341,10 @@ def test_get_logs_with_pod_name(self, mock_space_exists, mock_list_pods, mock_ve @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') @patch.object(HPSpace, 'list_pods') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_without_pod_name(self, mock_space_exists, mock_list_pods, mock_verify_config, mock_core_api_class): + def test_get_logs_without_pod_name(self, mock_list_pods, mock_verify_config, mock_core_api_class): """Test getting logs without pod name (uses first available pod)""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = True mock_core_api.read_namespaced_pod_log.return_value = "test logs" mock_list_pods.return_value = ["pod1", "pod2"] @@ -472,10 +358,8 @@ def test_get_logs_without_pod_name(self, mock_space_exists, mock_list_pods, mock @patch.object(HPSpace, 'verify_kube_config') @patch.object(HPSpace, 'list_pods') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_no_pods(self, mock_space_exists, mock_list_pods, mock_verify_config): + def test_get_logs_no_pods(self, mock_list_pods, mock_verify_config): """Test getting logs when no pods are available""" - mock_space_exists.return_value = True mock_list_pods.return_value = [] with self.assertRaises(RuntimeError) as context: @@ -484,12 +368,10 @@ def test_get_logs_no_pods(self, mock_space_exists, mock_list_pods, mock_verify_c @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_with_container(self, mock_space_exists, mock_verify_config, mock_core_api_class): + def test_get_logs_with_container(self, mock_verify_config, mock_core_api_class): """Test getting logs with specific 
container""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = True mock_core_api.read_namespaced_pod_log.return_value = "container logs" result = self.hp_space.get_logs(pod_name="test-pod", container="test-container") @@ -501,29 +383,13 @@ def test_get_logs_with_container(self, mock_space_exists, mock_verify_config, mo container="test-container" ) - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') - @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_not_exists(self, mock_space_exists, mock_verify_config, mock_core_api_class): - """Test getting logs when space doesn't exist""" - mock_core_api = Mock() - mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = False - - result = self.hp_space.get_logs(pod_name="test-pod") - - self.assertEqual(result, "") - mock_core_api.read_namespaced_pod_log.assert_not_called() - @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - @patch.object(HPSpace, 'space_exists') - def test_get_logs_failure(self, mock_space_exists, mock_handle_exception, mock_verify_config, mock_core_api_class): + def test_get_logs_failure(self, mock_handle_exception, mock_verify_config, mock_core_api_class): """Test getting logs failure""" mock_core_api = Mock() mock_core_api_class.return_value = mock_core_api - mock_space_exists.return_value = True mock_core_api.read_namespaced_pod_log.side_effect = Exception("Get logs failed") self.hp_space.get_logs(pod_name="test-pod") @@ -542,12 +408,10 @@ def test_model_extra_forbid(self): @patch('sagemaker.hyperpod.space.hyperpod_space.setup_logging') @patch.object(HPSpace, 'verify_kube_config') - @patch.object(HPSpace, 'space_exists') - def test_create_debug_logging(self, mock_space_exists, mock_verify_config, mock_setup_logging): + def 
test_create_debug_logging(self, mock_verify_config, mock_setup_logging): """Test create method with debug logging enabled""" mock_logger = Mock() mock_setup_logging.return_value = mock_logger - mock_space_exists.return_value = False # Mock domain config for YAML serialization mock_domain_config = { From 51a641504dd11639656d7485947834bf656c1f45 Mon Sep 17 00:00:00 2001 From: Brian Xia Date: Fri, 7 Nov 2025 07:58:46 -0800 Subject: [PATCH 16/31] Add dev_space_access.py CLI command (#259) * Add dev_space_access.py CLI command * Add space access creation to pySDK and refactor space access CLI --------- Co-authored-by: Brian Xia --- .../hyperpod/cli/commands/space_access.py | 21 +++++ .../cli/constants/space_access_constants.py | 4 +- src/sagemaker/hyperpod/cli/hyp_cli.py | 2 + .../hyperpod/space/hyperpod_space.py | 50 +++++++++++ test/unit_tests/cli/test_space_access.py | 67 ++++++++++++++ test/unit_tests/test_hyperpod_space.py | 88 +++++++++++++++++++ 6 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 src/sagemaker/hyperpod/cli/commands/space_access.py create mode 100644 test/unit_tests/cli/test_space_access.py diff --git a/src/sagemaker/hyperpod/cli/commands/space_access.py b/src/sagemaker/hyperpod/cli/commands/space_access.py new file mode 100644 index 00000000..fbe36e63 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/commands/space_access.py @@ -0,0 +1,21 @@ +import click +from sagemaker.hyperpod.space.hyperpod_space import HPSpace +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature + + +@click.command("hyp-space-access") +@click.option("--name", required=True, help="Name of the space") +@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") +@click.option("--connection-type", "-t", required=False, default="vscode-remote", help="Remote access type") +def space_access_create(name, 
namespace, connection_type): + """Create a space access resource.""" + + try: + space = HPSpace.get(name=name, namespace=namespace) + response = space.create_space_access(connection_type=connection_type) + click.echo(response) + except Exception as e: + click.echo(f"Error creating space access: {e}", err=True) diff --git a/src/sagemaker/hyperpod/cli/constants/space_access_constants.py b/src/sagemaker/hyperpod/cli/constants/space_access_constants.py index 55fc9522..ea27f5be 100644 --- a/src/sagemaker/hyperpod/cli/constants/space_access_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_access_constants.py @@ -11,6 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -SPACE_ACCESS_GROUP = "access.devspaces.sagemaker.aws.com" +SPACE_ACCESS_GROUP = "connection.workspace.jupyter.org" SPACE_ACCESS_VERSION = "v1alpha1" -SPACE_ACCESS_PLURAL = "devspaceaccess" +SPACE_ACCESS_PLURAL = "workspaceconnections" diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index ec60f303..bf4701e2 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -55,6 +55,7 @@ space_template_delete, space_template_update, ) +from sagemaker.hyperpod.cli.commands.space_access import space_access_create from sagemaker.hyperpod.cli.commands.init import ( init, @@ -208,6 +209,7 @@ def exec(): create.add_command(_default_create) create.add_command(space_create) create.add_command(space_template_create) +create.add_command(space_access_create) list.add_command(list_jobs) list.add_command(js_list) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index fe3f3c5f..e49e8b49 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -22,6 +22,11 @@ SPACE_VERSION, SPACE_PLURAL, ) +from 
sagemaker.hyperpod.cli.constants.space_access_constants import ( + SPACE_ACCESS_GROUP, + SPACE_ACCESS_VERSION, + SPACE_ACCESS_PLURAL, +) from hyperpod_space_template.v1_0.model import SpaceConfig @@ -341,3 +346,48 @@ def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = No return logs except Exception as e: handle_exception(e, pod_name, self.config.namespace) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space_access") + def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[str, str]: + """Create a space access for this space. + + Args: + connection_type (str, optional): The IDE type for remote access. Defaults to "vscode". + + Returns: + Dict[str, str]: Dictionary with 'SpaceConnectionType' and 'SpaceConnectionUrl' keys + + Raises: + Exception: If the space access creation fails + """ + self.verify_kube_config() + logger = self.get_logger() + + config = { + "metadata": { + "namespace": self.config.namespace, + }, + "spec": { + "workspaceName": self.config.name, + "workspaceConnectionType": connection_type, + } + } + + custom_api = client.CustomObjectsApi() + + try: + response = custom_api.create_namespaced_custom_object( + group=SPACE_ACCESS_GROUP, + version=SPACE_ACCESS_VERSION, + namespace=self.config.namespace, + plural=SPACE_ACCESS_PLURAL, + body=config + ) + logger.debug(f"Successfully created space access for '{self.config.name}'!") + return { + "SpaceConnectionType": connection_type, + "SpaceConnectionUrl": response["status"]["workspaceConnectionUrl"] + } + except Exception as e: + logger.error(f"Failed to create space access for {self.config.name}!") + handle_exception(e, self.config.name, self.config.namespace) diff --git a/test/unit_tests/cli/test_space_access.py b/test/unit_tests/cli/test_space_access.py new file mode 100644 index 00000000..9602edc7 --- /dev/null +++ b/test/unit_tests/cli/test_space_access.py @@ -0,0 +1,67 @@ +import pytest +from click.testing import CliRunner +from 
unittest.mock import Mock, patch + +from sagemaker.hyperpod.cli.commands.space_access import space_access_create + + +class TestSpaceAccessCommands: + """Test cases for space access commands""" + + def setup_method(self): + self.runner = CliRunner() + + @patch('sagemaker.hyperpod.cli.commands.space_access.HPSpace') + def test_space_access_create_success(self, mock_hp_space_class): + """Test successful space access creation""" + # Mock HPSpace.get() and create_space_access() + mock_space_instance = Mock() + mock_space_instance.create_space_access.return_value = { + "SpaceConnectionType": "vscode-remote", + "SpaceConnectionUrl": "https://test-url.com" + } + mock_hp_space_class.get.return_value = mock_space_instance + + result = self.runner.invoke(space_access_create, [ + '--name', 'test-space', + '--namespace', 'test-namespace', + '--connection-type', 'vscode-remote' + ]) + + assert result.exit_code == 0 + assert "https://test-url.com" in result.output + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-namespace') + mock_space_instance.create_space_access.assert_called_once_with(connection_type='vscode-remote') + + @patch('sagemaker.hyperpod.cli.commands.space_access.HPSpace') + def test_space_access_create_default_values(self, mock_hp_space_class): + """Test space access creation with default values""" + mock_space_instance = Mock() + mock_space_instance.create_space_access.return_value = { + "SpaceConnectionType": "vscode-remote", + "SpaceConnectionUrl": "https://default-url.com" + } + mock_hp_space_class.get.return_value = mock_space_instance + + result = self.runner.invoke(space_access_create, [ + '--name', 'test-space' + ]) + + assert result.exit_code == 0 + assert "https://default-url.com" in result.output + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='default') + mock_space_instance.create_space_access.assert_called_once_with(connection_type='vscode-remote') + + 
@patch('sagemaker.hyperpod.cli.commands.space_access.HPSpace') + def test_space_access_create_api_error(self, mock_hp_space_class): + """Test space access creation when API call fails""" + mock_space_instance = Mock() + mock_space_instance.create_space_access.side_effect = Exception("API error") + mock_hp_space_class.get.return_value = mock_space_instance + + result = self.runner.invoke(space_access_create, [ + '--name', 'test-space' + ]) + + assert result.exit_code == 0 + assert "Error creating space access: API error" in result.output diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index 60689b01..ba481ab9 100644 --- a/test/unit_tests/test_hyperpod_space.py +++ b/test/unit_tests/test_hyperpod_space.py @@ -433,3 +433,91 @@ def test_get_logger(self): """Test get_logger class method""" logger = HPSpace.get_logger() self.assertEqual(logger.name, "sagemaker.hyperpod.space.hyperpod_space") + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + def test_create_space_access_success(self, mock_verify_config, mock_custom_api_class): + """Test successful space access creation""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = { + "status": { + "workspaceConnectionUrl": "https://example.com/vscode-access" + } + } + mock_custom_api.create_namespaced_custom_object.return_value = mock_response + + result = self.hp_space.create_space_access() + + expected_config = { + "metadata": { + "namespace": "test-namespace", + }, + "spec": { + "workspaceName": "test-space", + "workspaceConnectionType": "vscode-remote", + } + } + + mock_verify_config.assert_called_once() + mock_custom_api.create_namespaced_custom_object.assert_called_once_with( + group="connection.workspace.jupyter.org", + version="v1alpha1", + namespace="test-namespace", + plural="workspaceconnections", + body=expected_config + ) + 
self.assertEqual(result, {"SpaceConnectionType": "vscode-remote", "SpaceConnectionUrl": "https://example.com/vscode-access"}) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + def test_create_space_access_custom_ide(self, mock_verify_config, mock_custom_api_class): + """Test space access creation with custom IDE type""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + mock_response = { + "status": { + "workspaceConnectionUrl": "https://example.com/webui-access" + } + } + mock_custom_api.create_namespaced_custom_object.return_value = mock_response + + result = self.hp_space.create_space_access(connection_type="web-ui") + + expected_config = { + "metadata": { + "namespace": "test-namespace", + }, + "spec": { + "workspaceName": "test-space", + "workspaceConnectionType": "web-ui", + } + } + + mock_custom_api.create_namespaced_custom_object.assert_called_once_with( + group="connection.workspace.jupyter.org", + version="v1alpha1", + namespace="test-namespace", + plural="workspaceconnections", + body=expected_config + ) + self.assertEqual(result, {"SpaceConnectionType": "web-ui", "SpaceConnectionUrl": "https://example.com/webui-access"}) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') + def test_create_space_access_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + """Test space access creation failure""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_custom_api.create_namespaced_custom_object.side_effect = Exception("Access creation failed") + + self.hp_space.create_space_access() + + mock_handle_exception.assert_called_once_with( + mock_custom_api.create_namespaced_custom_object.side_effect, + "test-space", + "test-namespace" + ) From 
24171e6367ebd381333edc8b47b5819c28d6fed2 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Fri, 7 Nov 2025 15:46:14 -0800 Subject: [PATCH 17/31] Listing space will filter out the spaces not created by the current user (#285) * Implement CRUD operations for Space PySDK * Update Space PySDK per new schema * Implement CRUD operations for Space PySDK * Update Space PySDK per new schema * Update Space PySDK per new schema * Implement space list pagination and creator filtering --- src/sagemaker/hyperpod/cli/commands/space.py | 4 +- .../hyperpod/space/hyperpod_space.py | 65 +++--- test/unit_tests/test_hyperpod_space.py | 216 +++++++++++++++++- 3 files changed, 248 insertions(+), 37 deletions(-) diff --git a/src/sagemaker/hyperpod/cli/commands/space.py b/src/sagemaker/hyperpod/cli/commands/space.py index 9fd9ba31..28c46b71 100644 --- a/src/sagemaker/hyperpod/cli/commands/space.py +++ b/src/sagemaker/hyperpod/cli/commands/space.py @@ -101,7 +101,7 @@ def space_delete(name, namespace): current_space = HPSpace.get(name=name, namespace=namespace) current_space.delete() - click.echo(f"Space '{name}' deleted successfully") + click.echo(f"Space '{name}' deleted successfully in namespace '{namespace}'") except Exception as e: click.echo(f"Error deleting space '{name}': {e}", err=True) @@ -121,7 +121,7 @@ def space_update(version, config): current_space.update(**config) - click.echo(f"Space '{current_space.config.name}' updated successfully") + click.echo(f"Space '{current_space.config.name}' updated successfully in namespace '{config['namespace']}'") except Exception as e: click.echo(f"Error updating space: {e}", err=True) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index e49e8b49..b0fd9f11 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -1,5 +1,6 @@ import logging import yaml +import boto3 from typing import List, Optional, ClassVar, Dict, Any from 
pydantic import BaseModel, Field, ConfigDict from kubernetes import client, config @@ -119,7 +120,7 @@ def create(self, debug: bool = False): plural=SPACE_PLURAL, body=config_body, ) - logger.info(f"Successfully created HyperPod Space '{self.config.name}'!") + logger.debug(f"Successfully created HyperPod Space '{self.config.name}'!") except Exception as e: logger.error(f"Failed to create HyperPod Space {self.config.name}!") handle_exception(e, self.config.name, self.config.namespace) @@ -127,14 +128,14 @@ def create(self, debug: bool = False): @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_spaces") def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: - """List all HyperPod Spaces in the specified namespace. + """List all HyperPod Spaces in the specified namespace created by the caller. Args: namespace (str, optional): The Kubernetes namespace to list spaces from. If None, uses the default namespace from current context. Returns: - List[HPSpace]: List of HPSpace instances found in the namespace + List[HPSpace]: List of HPSpace instances created by the caller Raises: Exception: If the Kubernetes API call fails or spaces cannot be retrieved @@ -144,33 +145,43 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: if not namespace: namespace = get_default_namespace() + # Get caller identity + sts_client = boto3.client('sts') + caller_identity = sts_client.get_caller_identity() + caller_arn = caller_identity['Arn'] + custom_api = client.CustomObjectsApi() + spaces = [] + continue_token = None try: - response = custom_api.list_namespaced_custom_object( - group=SPACE_GROUP, - version=SPACE_VERSION, - namespace=namespace, - plural=SPACE_PLURAL + while True: + response = custom_api.list_namespaced_custom_object( + group=SPACE_GROUP, + version=SPACE_VERSION, + namespace=namespace, + plural=SPACE_PLURAL, + _continue=continue_token ) - spaces = [] - for item in response.get("items", []): - # Create SpaceConfig from the Kubernetes 
resource - spec = item.get("spec", {}) - config_data = { - "name": item["metadata"]["name"], - "namespace": item["metadata"]["namespace"], - } - - config_data = map_kubernetes_response_to_model(item, SpaceConfig) - space_config = SpaceConfig(**config_data) - - space = cls( - config=space_config, - raw_resource=item - ) - spaces.append(space) + for item in response.get("items", []): + # Check if space was created by the caller + # TODO: need to also check OwnershipType when it's implemented in the operator + created_by = item.get('metadata', {}).get('annotations', {}).get('workspace.jupyter.org/created-by') + if created_by == caller_arn: + config_data = map_kubernetes_response_to_model(item, SpaceConfig) + space_config = SpaceConfig(**config_data) + + space = cls( + config=space_config, + raw_resource=item + ) + spaces.append(space) + + # Check if there are more pages + continue_token = response.get('metadata', {}).get('continue') + if not continue_token: + break return spaces except Exception as e: @@ -236,7 +247,7 @@ def delete(self): plural=SPACE_PLURAL, name=self.config.name ) - logger.info(f"Successfully deleted HyperPod Space '{self.config.name}'!") + logger.debug(f"Successfully deleted HyperPod Space '{self.config.name}'!") except Exception as e: logger.error(f"Failed to delete HyperPod Space {self.config.name}!") handle_exception(e, self.config.name, self.config.namespace) @@ -274,7 +285,7 @@ def update(self, **kwargs): name=self.config.name, body={"spec": spec_updates} ) - logger.info(f"Successfully updated HyperPod Space '{self.config.name}'!") + logger.debug(f"Successfully updated HyperPod Space '{self.config.name}'!") except Exception as e: logger.error(f"Failed to update HyperPod Space {self.config.name}!") handle_exception(e, self.config.name, self.config.namespace) diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index ba481ab9..99776583 100644 --- a/test/unit_tests/test_hyperpod_space.py +++ 
b/test/unit_tests/test_hyperpod_space.py @@ -97,25 +97,41 @@ def test_create_failure(self, mock_handle_exception, mock_verify_config, mock_cu mock_handle_exception.assert_called_once() - - + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.get_default_namespace') - def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class): + def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class, mock_boto3_client): """Test successful dev space listing""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_get_namespace.return_value = "default" + # Mock STS client for caller identity + mock_sts_client = Mock() + mock_sts_client.get_caller_identity.return_value = {'Arn': 'arn:aws:iam::123456789012:user/test-user'} + mock_boto3_client.return_value = mock_sts_client + mock_response = { "items": [ { - "metadata": {"name": "space1", "namespace": "default"}, + "metadata": { + "name": "space1", + "namespace": "default", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/test-user" + } + }, "spec": {"image": "image1:latest", "displayName": "Space 1"}, }, { - "metadata": {"name": "space2", "namespace": "default"}, + "metadata": { + "name": "space2", + "namespace": "default", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/test-user" + } + }, "spec": {"image": "image2:latest", "displayName": "Space 2"}, } ] @@ -129,13 +145,19 @@ def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_ self.assertEqual(result[1].config.name, "space2") mock_custom_api.list_namespaced_custom_object.assert_called_once() + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') 
@patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') - def test_list_with_namespace(self, mock_verify_config, mock_custom_api_class): + def test_list_with_namespace(self, mock_verify_config, mock_custom_api_class, mock_boto3_client): """Test dev space listing with specific namespace""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api + # Mock STS client for caller identity + mock_sts_client = Mock() + mock_sts_client.get_caller_identity.return_value = {'Arn': 'arn:aws:iam::123456789012:user/test-user'} + mock_boto3_client.return_value = mock_sts_client + mock_response = {"items": []} mock_custom_api.list_namespaced_custom_object.return_value = mock_response @@ -145,13 +167,191 @@ def test_list_with_namespace(self, mock_verify_config, mock_custom_api_class): group="workspace.jupyter.org", version="v1alpha1", namespace="custom-namespace", - plural="workspaces" + plural="workspaces", + _continue=None ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.get_default_namespace') + def test_list_filters_by_creator(self, mock_get_namespace, mock_verify_config, mock_custom_api_class, mock_boto3_client): + """Test that list only returns spaces created by the caller""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_get_namespace.return_value = "default" + + # Mock STS client for caller identity + mock_sts_client = Mock() + mock_sts_client.get_caller_identity.return_value = {'Arn': 'arn:aws:iam::123456789012:user/test-user'} + mock_boto3_client.return_value = mock_sts_client + + # Mock response with spaces from different creators + mock_response = { + "items": [ + { + "metadata": { + "name": "my-space", + "namespace": "default", + "annotations": { + 
"workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/test-user" + } + }, + "spec": {"image": "image1:latest", "displayName": "My Space"}, + }, + { + "metadata": { + "name": "other-space", + "namespace": "default", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/other-user" + } + }, + "spec": {"image": "image2:latest", "displayName": "Other Space"}, + }, + { + "metadata": { + "name": "no-annotation-space", + "namespace": "default" + }, + "spec": {"image": "image3:latest", "displayName": "No Annotation Space"}, + } + ] + } + mock_custom_api.list_namespaced_custom_object.return_value = mock_response + + result = HPSpace.list() + + # Should only return the space created by the current user + self.assertEqual(len(result), 1) + self.assertEqual(result[0].config.name, "my-space") + + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space.get_default_namespace') + def test_list_pagination_multiple_pages(self, mock_get_namespace, mock_verify_config, mock_custom_api_class, mock_boto3_client): + """Test pagination with multiple pages""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + mock_get_namespace.return_value = "default" + + # Mock STS client + mock_sts_client = Mock() + mock_sts_client.get_caller_identity.return_value = {'Arn': 'arn:aws:iam::123456789012:user/test-user'} + mock_boto3_client.return_value = mock_sts_client + + # Mock responses for multiple pages + first_page_response = { + "items": [ + { + "metadata": { + "name": "space1", + "namespace": "default", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/test-user" + } + }, + "spec": {"image": "image1:latest", "displayName": "Space 1"}, + } + ], + "metadata": {"continue": "page2-token"} + } + + 
second_page_response = { + "items": [ + { + "metadata": { + "name": "space2", + "namespace": "default", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/test-user" + } + }, + "spec": {"image": "image2:latest", "displayName": "Space 2"}, + } + ], + "metadata": {} # No continue token (last page) + } + + mock_custom_api.list_namespaced_custom_object.side_effect = [first_page_response, second_page_response] + + result = HPSpace.list() + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].config.name, "space1") + self.assertEqual(result[1].config.name, "space2") + + # Should be called twice (two pages) + self.assertEqual(mock_custom_api.list_namespaced_custom_object.call_count, 2) + + # Verify the calls + calls = mock_custom_api.list_namespaced_custom_object.call_args_list + self.assertEqual(calls[0][1]['_continue'], None) # First call + self.assertEqual(calls[1][1]['_continue'], "page2-token") # Second call with token + + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + def test_list_no_matching_spaces_across_pages(self, mock_verify_config, mock_custom_api_class, mock_boto3_client): + """Test pagination when no spaces match the creator filter""" + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + # Mock STS client + mock_sts_client = Mock() + mock_sts_client.get_caller_identity.return_value = {'Arn': 'arn:aws:iam::123456789012:user/test-user'} + mock_boto3_client.return_value = mock_sts_client + + # Mock responses with no matching creators + first_page_response = { + "items": [ + { + "metadata": { + "name": "other-space1", + "namespace": "test-namespace", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/other-user" + } + }, + "spec": {"image": "image1:latest", "displayName": "Other Space 1"}, + } + ], + 
"metadata": {"continue": "page2-token"} + } + + second_page_response = { + "items": [ + { + "metadata": { + "name": "another-space", + "namespace": "test-namespace", + "annotations": { + "workspace.jupyter.org/created-by": "arn:aws:iam::123456789012:user/another-user" + } + }, + "spec": {"image": "image2:latest", "displayName": "Another Space"}, + } + ], + "metadata": {} # No continue token (last page) + } + + mock_custom_api.list_namespaced_custom_object.side_effect = [first_page_response, second_page_response] + + result = HPSpace.list(namespace="test-namespace") + + # Should return empty list (no matching creators) + self.assertEqual(len(result), 0) + + # Should still paginate through all pages + self.assertEqual(mock_custom_api.list_namespaced_custom_object.call_count, 2) + + @patch('sagemaker.hyperpod.space.hyperpod_space.boto3.client') @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') - def test_list_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + def test_list_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class, mock_boto3_client): """Test dev space listing failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api From c008e928422cb577f445494cd1fa29dfda9addbe Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Tue, 11 Nov 2025 18:13:27 -0800 Subject: [PATCH 18/31] Refactor space template with PySDK (#286) --- .../hyperpod/cli/commands/space_template.py | 77 ++--- .../hyperpod/space/hyperpod_space_template.py | 38 ++- test/unit_tests/cli/test_space_template.py | 270 +++++------------- .../test_hyperpod_space_template.py | 30 +- 4 files changed, 134 insertions(+), 281 deletions(-) diff --git a/src/sagemaker/hyperpod/cli/commands/space_template.py b/src/sagemaker/hyperpod/cli/commands/space_template.py index 92941d79..540125ae 100644 --- 
a/src/sagemaker/hyperpod/cli/commands/space_template.py +++ b/src/sagemaker/hyperpod/cli/commands/space_template.py @@ -2,25 +2,17 @@ import json import yaml from tabulate import tabulate -from sagemaker.hyperpod.cli.clients.kubernetes_client import KubernetesClient +from sagemaker.hyperpod.space.hyperpod_space_template import HPSpaceTemplate @click.command("hyp-space-template") @click.option("--file", "-f", required=True, help="YAML file containing the configuration") def space_template_create(file): """Create a space-template resource.""" - k8s_client = KubernetesClient() - try: - with open(file, 'r') as f: - config_data = yaml.safe_load(f) - - k8s_client.create_space_template(config_data) - click.echo(f"Space template '{config_data['metadata']['name']}' created successfully") - except FileNotFoundError: - click.echo(f"Error: File '{file}' not found", err=True) - except yaml.YAMLError as e: - click.echo(f"Error parsing YAML file: {e}", err=True) + template = HPSpaceTemplate(file_path=file) + template.create() + click.echo(f"Space template '{template.name}' created successfully") except Exception as e: click.echo(f"Error creating space template: {e}", err=True) @@ -29,22 +21,22 @@ def space_template_create(file): @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") def space_template_list(output): """List space-template resources.""" - k8s_client = KubernetesClient() - try: - resources = k8s_client.list_space_templates() + templates = HPSpaceTemplate.list() if output == "json": - click.echo(json.dumps(resources, indent=2)) + templates_data = [template.to_dict() for template in templates] + click.echo(json.dumps(templates_data, indent=2)) else: - items = resources.get("items", []) - if items: + if templates: table_data = [] - for item in items: + for template in templates: table_data.append([ - item["metadata"]["name"], + template.name, + template.config_data.get("spec", {}).get("displayName", ""), + 
template.config_data.get("spec", {}).get("defaultImage", ""), ]) - click.echo(tabulate(table_data, headers=["NAME"])) + click.echo(tabulate(table_data, headers=["NAME", "DISPLAY_NAME", "DEFAULT_IMAGE"])) else: click.echo("No space templates found") except Exception as e: @@ -52,32 +44,28 @@ def space_template_list(output): @click.command("hyp-space-template") -@click.option("--name", required=False, help="Name of the space template") +@click.option("--name", required=True, help="Name of the space template") @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") def space_template_describe(name, output): """Describe a space-template resource.""" - k8s_client = KubernetesClient() - try: - resource = k8s_client.get_space_template(name) - resource["metadata"].pop('managedFields', None) + template = HPSpaceTemplate.get(name) if output == "json": - click.echo(json.dumps(resource, indent=2)) + click.echo(json.dumps(template.to_dict(), indent=2)) else: - click.echo(yaml.dump(resource, default_flow_style=False)) + click.echo(template.to_yaml()) except Exception as e: click.echo(f"Error describing space template '{name}': {e}", err=True) @click.command("hyp-space-template") -@click.option("--name", required=False, help="Name of the space template") +@click.option("--name", required=True, help="Name of the space template") def space_template_delete(name): """Delete a space-template resource.""" - k8s_client = KubernetesClient() - try: - k8s_client.delete_space_template(name) + template = HPSpaceTemplate.get(name) + template.delete() click.echo(f"Space template '{name}' deleted successfully") except Exception as e: click.echo(f"Error deleting space template '{name}': {e}", err=True) @@ -88,30 +76,9 @@ def space_template_delete(name): @click.option("--file", "-f", required=True, help="YAML file containing the updated template") def space_template_update(name, file): """Update a space-template resource.""" - k8s_client = KubernetesClient() - try: - 
with open(file, 'r') as f: - config_data = yaml.safe_load(f) - - # Validate that the name matches - yaml_name = config_data.get('metadata', {}).get('name') - if yaml_name and yaml_name != name: - click.echo(f"Error: Name mismatch. CLI parameter '{name}' does not match YAML name '{yaml_name}'", err=True) - return - - # Remove immutable fields from the update - if 'metadata' in config_data: - config_data['metadata'].pop('resourceVersion', None) - config_data['metadata'].pop('uid', None) - config_data['metadata'].pop('creationTimestamp', None) - config_data['metadata'].pop('managedFields', None) - - k8s_client.patch_space_template(name, config_data) + template = HPSpaceTemplate.get(name) + template.update(file) click.echo(f"Space template '{name}' updated successfully") - except FileNotFoundError: - click.echo(f"Error: File '{file}' not found", err=True) - except yaml.YAMLError as e: - click.echo(f"Error parsing YAML file: {e}", err=True) except Exception as e: click.echo(f"Error updating space template '{name}': {e}", err=True) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space_template.py b/src/sagemaker/hyperpod/space/hyperpod_space_template.py index 0499483e..80ae8800 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space_template.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space_template.py @@ -1,6 +1,6 @@ import logging import yaml -from typing import List, Optional, ClassVar, Dict, Any, Union +from typing import List, Optional, ClassVar, Dict, Any from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -28,19 +28,31 @@ class HPSpaceTemplate: is_kubeconfig_loaded: ClassVar[bool] = False - def __init__(self, file_path: str): - """Initialize space template with config YAML file path. + def __init__(self, *, file_path: Optional[str] = None, config_data: Optional[Dict[str, Any]] = None): + """Initialize space template with config YAML file path or dictionary data. 
Args: - file_path: Path to YAML file + file_path: Path to YAML configuration file + config_data: Dictionary containing configuration data + + Raises: + ValueError: If both or neither parameters are provided """ - try: - with open(file_path, 'r') as f: - self.config_data = yaml.safe_load(f) - except FileNotFoundError: - raise FileNotFoundError(f"File '{file_path}' not found") - except yaml.YAMLError as e: - raise ValueError(f"Error parsing YAML file: {e}") + if (file_path is None) == (config_data is None): + raise ValueError("Exactly one of 'file_path' or 'config_data' must be provided") + + if file_path is not None: + # Initialize from file path + try: + with open(file_path, 'r') as f: + self.config_data = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File '{file_path}' not found") + except yaml.YAMLError as e: + raise ValueError(f"Error parsing YAML file: {e}") + else: + # Initialize from dictionary data (e.g., from Kubernetes API response) + self.config_data = config_data self.name = self.config_data.get('metadata', {}).get('name') @@ -104,7 +116,7 @@ def list(cls) -> List["HPSpaceTemplate"]: templates = [] for item in response.get("items", []): - templates.append(cls(item)) + templates.append(cls(config_data=item)) return templates @@ -140,7 +152,7 @@ def get(cls, name: str) -> "HPSpaceTemplate": if 'metadata' in response: response['metadata'].pop('managedFields', None) - return cls(response) + return cls(config_data=response) except ApiException as e: handle_exception(e, name, None) diff --git a/test/unit_tests/cli/test_space_template.py b/test/unit_tests/cli/test_space_template.py index e468c85d..5f8dfd41 100644 --- a/test/unit_tests/cli/test_space_template.py +++ b/test/unit_tests/cli/test_space_template.py @@ -36,67 +36,38 @@ def setUp(self): "spec": {"displayName": "Test Template"} } - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: 
data") - @patch("yaml.safe_load") - def test_space_template_create_success(self, mock_yaml_load, mock_file, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_create_success(self, mock_hp_space_template): """Test successful space template creation""" - mock_yaml_load.return_value = self.mock_config_data - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance + mock_template_instance = Mock() + mock_template_instance.name = "test-template" + mock_hp_space_template.return_value = mock_template_instance result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) self.assertEqual(result.exit_code, 0) self.assertIn("Space template 'test-template' created successfully", result.output) - mock_client_instance.create_space_template.assert_called_once_with(self.mock_config_data) + mock_hp_space_template.assert_called_once_with(file_path="test.yaml") + mock_template_instance.create.assert_called_once() - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_create_file_not_found(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_create_file_not_found(self, mock_hp_space_template): """Test space template creation with missing file""" - result = self.runner.invoke(space_template_create, ["--file", "nonexistent.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error: File 'nonexistent.yaml' not found", result.output) - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="invalid: yaml: content:") - @patch("yaml.safe_load") - def test_space_template_create_yaml_error(self, mock_yaml_load, mock_file, mock_k8s_client): - """Test space template creation with YAML parsing error""" - mock_yaml_load.side_effect = yaml.YAMLError("Invalid YAML") - - result = 
self.runner.invoke(space_template_create, ["--file", "test.yaml"]) + mock_hp_space_template.side_effect = FileNotFoundError("File 'nonexistent.yaml' not found") - self.assertEqual(result.exit_code, 0) - self.assertIn("Error parsing YAML file: Invalid YAML", result.output) - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: data") - @patch("yaml.safe_load") - def test_space_template_create_k8s_error(self, mock_yaml_load, mock_file, mock_k8s_client): - """Test space template creation with Kubernetes error""" - mock_yaml_load.return_value = self.mock_config_data - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.create_space_template.side_effect = Exception("K8s error") - - result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) + result = self.runner.invoke(space_template_create, ["--file", "nonexistent.yaml"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Error creating space template: K8s error", result.output) + self.assertIn("Error creating space template: File 'nonexistent.yaml' not found", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_list_table_output(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_list_table_output(self, mock_hp_space_template): """Test space template list with table output""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.list_space_templates.return_value = { - "items": [ - {"metadata": {"name": "template1"}}, - {"metadata": {"name": "template2"}} - ] - } + mock_template1 = Mock() + mock_template1.name = "template1" + mock_template2 = Mock() + mock_template2.name = "template2" + mock_hp_space_template.list.return_value = [mock_template1, mock_template2] 
result = self.runner.invoke(space_template_list, ["--output", "table"]) @@ -105,225 +76,128 @@ def test_space_template_list_table_output(self, mock_k8s_client): self.assertIn("template2", result.output) self.assertIn("NAME", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_list_json_output(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_list_json_output(self, mock_hp_space_template): """Test space template list with JSON output""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_resources = { - "items": [ - {"metadata": {"name": "template1"}}, - {"metadata": {"name": "template2"}} - ] - } - mock_client_instance.list_space_templates.return_value = mock_resources + mock_template1 = Mock() + mock_template1.to_dict.return_value = {"metadata": {"name": "template1"}} + mock_template2 = Mock() + mock_template2.to_dict.return_value = {"metadata": {"name": "template2"}} + mock_hp_space_template.list.return_value = [mock_template1, mock_template2] result = self.runner.invoke(space_template_list, ["--output", "json"]) self.assertEqual(result.exit_code, 0) output_json = json.loads(result.output) - self.assertEqual(output_json, mock_resources) + self.assertEqual(len(output_json), 2) + self.assertEqual(output_json[0]["metadata"]["name"], "template1") + self.assertEqual(output_json[1]["metadata"]["name"], "template2") - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_list_empty(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_list_empty(self, mock_hp_space_template): """Test space template list with no templates""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.list_space_templates.return_value = {"items": 
[]} + mock_hp_space_template.list.return_value = [] result = self.runner.invoke(space_template_list, ["--output", "table"]) self.assertEqual(result.exit_code, 0) self.assertIn("No space templates found", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_list_error(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_list_error(self, mock_hp_space_template): """Test space template list with error""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.list_space_templates.side_effect = Exception("List error") + mock_hp_space_template.list.side_effect = Exception("List error") result = self.runner.invoke(space_template_list) self.assertEqual(result.exit_code, 0) self.assertIn("Error listing space templates: List error", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_describe_yaml_output(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_describe_yaml_output(self, mock_hp_space_template): """Test space template describe with YAML output""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_resource = { - "metadata": { - "name": "test-template", - "managedFields": [{"manager": "kubectl"}] - }, - "spec": {"displayName": "Test Template"} - } - mock_client_instance.get_space_template.return_value = mock_resource + mock_template_instance = Mock() + mock_template_instance.to_yaml.return_value = "name: test-template\nspec:\n displayName: Test Template" + mock_hp_space_template.get.return_value = mock_template_instance result = self.runner.invoke(space_template_describe, ["--name", "test-template", "--output", "yaml"]) self.assertEqual(result.exit_code, 0) self.assertIn("name: test-template", 
result.output) self.assertIn("displayName: Test Template", result.output) - # managedFields should be removed - self.assertNotIn("managedFields", result.output) + mock_hp_space_template.get.assert_called_once_with("test-template") - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_describe_json_output(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_describe_json_output(self, mock_hp_space_template): """Test space template describe with JSON output""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_resource = { - "metadata": { - "name": "test-template", - "managedFields": [{"manager": "kubectl"}] - }, + mock_template_instance = Mock() + mock_template_instance.to_dict.return_value = { + "metadata": {"name": "test-template"}, "spec": {"displayName": "Test Template"} } - mock_client_instance.get_space_template.return_value = mock_resource + mock_hp_space_template.get.return_value = mock_template_instance result = self.runner.invoke(space_template_describe, ["--name", "test-template", "--output", "json"]) self.assertEqual(result.exit_code, 0) output_json = json.loads(result.output) self.assertEqual(output_json["metadata"]["name"], "test-template") - self.assertNotIn("managedFields", output_json["metadata"]) + self.assertEqual(output_json["spec"]["displayName"], "Test Template") - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_describe_error(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_describe_error(self, mock_hp_space_template): """Test space template describe with error""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.get_space_template.side_effect = Exception("Not found") + 
mock_hp_space_template.get.side_effect = Exception("Not found") result = self.runner.invoke(space_template_describe, ["--name", "nonexistent"]) self.assertEqual(result.exit_code, 0) self.assertIn("Error describing space template 'nonexistent': Not found", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_delete_success(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_delete_success(self, mock_hp_space_template): """Test successful space template deletion""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance + mock_template_instance = Mock() + mock_hp_space_template.get.return_value = mock_template_instance result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) self.assertEqual(result.exit_code, 0) self.assertIn("Space template 'test-template' deleted successfully", result.output) - mock_client_instance.delete_space_template.assert_called_once_with("test-template") + mock_hp_space_template.get.assert_called_once_with("test-template") + mock_template_instance.delete.assert_called_once() - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_delete_error(self, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_delete_error(self, mock_hp_space_template): """Test space template deletion with error""" - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - mock_client_instance.delete_space_template.side_effect = Exception("Delete error") + mock_hp_space_template.get.side_effect = Exception("Delete error") result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) self.assertEqual(result.exit_code, 0) self.assertIn("Error deleting space template 'test-template': Delete error", result.output) - 
@patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: data") - @patch("yaml.safe_load") - def test_space_template_update_success(self, mock_yaml_load, mock_file, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_update_success(self, mock_hp_space_template): """Test successful space template update""" - mock_yaml_load.return_value = self.mock_config_data - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance + mock_template_instance = Mock() + mock_hp_space_template.get.return_value = mock_template_instance result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) self.assertEqual(result.exit_code, 0) self.assertIn("Space template 'test-template' updated successfully", result.output) - mock_client_instance.patch_space_template.assert_called_once() - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: data") - @patch("yaml.safe_load") - def test_space_template_update_name_mismatch(self, mock_yaml_load, mock_file, mock_k8s_client): - """Test space template update with name mismatch""" - config_with_different_name = self.mock_config_data.copy() - config_with_different_name["metadata"]["name"] = "different-name" - mock_yaml_load.return_value = config_with_different_name - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - - result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error: Name mismatch. 
CLI parameter 'test-template' does not match YAML name 'different-name'", result.output) - mock_client_instance.patch_space_template.assert_not_called() - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - def test_space_template_update_file_not_found(self, mock_k8s_client): - """Test space template update with missing file""" - result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "nonexistent.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error: File 'nonexistent.yaml' not found", result.output) + mock_hp_space_template.get.assert_called_once_with("test-template") + mock_template_instance.update.assert_called_once_with("test.yaml") - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="invalid: yaml: content:") - @patch("yaml.safe_load") - def test_space_template_update_yaml_error(self, mock_yaml_load, mock_file, mock_k8s_client): + @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") + def test_space_template_update_yaml_error(self, mock_hp_space_template): """Test space template update with YAML parsing error""" - mock_yaml_load.side_effect = yaml.YAMLError("Invalid YAML") - - result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error parsing YAML file: Invalid YAML", result.output) - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: data") - @patch("yaml.safe_load") - def test_space_template_update_k8s_error(self, mock_yaml_load, mock_file, mock_k8s_client): - """Test space template update with Kubernetes error""" - mock_yaml_load.return_value = self.mock_config_data - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - 
mock_client_instance.patch_space_template.side_effect = Exception("K8s error") + mock_template_instance = Mock() + mock_template_instance.update.side_effect = yaml.YAMLError("Invalid YAML") + mock_hp_space_template.get.return_value = mock_template_instance result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Error updating space template 'test-template': K8s error", result.output) - - @patch("sagemaker.hyperpod.cli.commands.space_template.KubernetesClient") - @patch("builtins.open", new_callable=mock_open, read_data="test: data") - @patch("yaml.safe_load") - def test_space_template_update_removes_immutable_fields(self, mock_yaml_load, mock_file, mock_k8s_client): - """Test space template update removes immutable fields""" - config_with_immutable_fields = { - "metadata": { - "name": "test-template", - "resourceVersion": "12345", - "uid": "abc-123", - "creationTimestamp": "2023-01-01T00:00:00Z", - "managedFields": [{"manager": "kubectl"}] - }, - "spec": {"displayName": "Test Template"} - } - mock_yaml_load.return_value = config_with_immutable_fields - mock_client_instance = Mock() - mock_k8s_client.return_value = mock_client_instance - - result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) - - self.assertEqual(result.exit_code, 0) - # Verify patch was called with cleaned config - call_args = mock_client_instance.patch_space_template.call_args[0][1] - self.assertNotIn("resourceVersion", call_args["metadata"]) - self.assertNotIn("uid", call_args["metadata"]) - self.assertNotIn("creationTimestamp", call_args["metadata"]) - self.assertNotIn("managedFields", call_args["metadata"]) - self.assertEqual(call_args["metadata"]["name"], "test-template") - - -if __name__ == "__main__": - unittest.main() + self.assertIn("Error updating space template 'test-template': Invalid YAML", result.output) diff --git 
a/test/unit_tests/test_hyperpod_space_template.py b/test/unit_tests/test_hyperpod_space_template.py index 9d03534b..8e918bb9 100644 --- a/test/unit_tests/test_hyperpod_space_template.py +++ b/test/unit_tests/test_hyperpod_space_template.py @@ -31,7 +31,7 @@ def test_init_success(self, mock_yaml_load, mock_file): mock_yaml_load.return_value = self.mock_config_data mock_file.return_value.read.return_value = self.yaml_content - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") self.assertEqual(template.config_data, self.mock_config_data) self.assertEqual(template.name, "test-template") @@ -41,7 +41,7 @@ def test_init_success(self, mock_yaml_load, mock_file): def test_init_file_not_found(self, mock_file): """Test initialization with non-existent file""" with self.assertRaises(FileNotFoundError) as context: - HPSpaceTemplate("nonexistent.yaml") + HPSpaceTemplate(file_path="nonexistent.yaml") self.assertIn("File 'nonexistent.yaml' not found", str(context.exception)) @patch('builtins.open', new_callable=mock_open) @@ -49,7 +49,7 @@ def test_init_file_not_found(self, mock_file): def test_init_yaml_error(self, mock_yaml_load, mock_file): """Test initialization with invalid YAML""" with self.assertRaises(ValueError) as context: - HPSpaceTemplate("invalid.yaml") + HPSpaceTemplate(file_path="invalid.yaml") self.assertIn("Error parsing YAML file", str(context.exception)) @patch('sagemaker.hyperpod.space.hyperpod_space_template.config.load_kube_config') @@ -82,7 +82,7 @@ def test_create_success(self, mock_verify_config, mock_custom_api_class, mock_ya mock_custom_api_class.return_value = mock_custom_api mock_custom_api.create_cluster_custom_object.return_value = self.mock_config_data - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.create() mock_verify_config.assert_called_once() @@ -105,7 +105,7 @@ def test_create_api_exception(self, mock_handle_exception, mock_verify_config, m 
mock_custom_api_class.return_value = mock_custom_api mock_custom_api.create_cluster_custom_object.side_effect = ApiException(status=409) - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.create() mock_handle_exception.assert_called_once() @@ -121,7 +121,7 @@ def test_create_general_exception(self, mock_verify_config, mock_custom_api_clas mock_custom_api_class.return_value = mock_custom_api mock_custom_api.create_cluster_custom_object.side_effect = Exception("Creation failed") - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") with self.assertRaises(Exception): template.create() @@ -236,7 +236,7 @@ def test_delete_success(self, mock_verify_config, mock_custom_api_class, mock_ya mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.delete() mock_verify_config.assert_called_once() @@ -259,7 +259,7 @@ def test_delete_api_exception(self, mock_handle_exception, mock_verify_config, m mock_custom_api_class.return_value = mock_custom_api mock_custom_api.delete_cluster_custom_object.side_effect = ApiException(status=404) - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.delete() mock_handle_exception.assert_called_once() @@ -275,7 +275,7 @@ def test_update_success(self, mock_verify_config, mock_custom_api_class, mock_ya mock_custom_api_class.return_value = mock_custom_api mock_custom_api.patch_cluster_custom_object.return_value = self.mock_config_data - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.update("updated.yaml") mock_verify_config.assert_called_once() @@ -297,7 +297,7 @@ def test_update_name_mismatch(self, mock_verify_config, mock_yaml_load, mock_fil {"metadata": {"name": "different-name"}} ] - template = HPSpaceTemplate("test.yaml") + 
template = HPSpaceTemplate(file_path="test.yaml") with self.assertRaises(ValueError) as context: template.update("different.yaml") @@ -311,7 +311,7 @@ def test_update_file_not_found(self, mock_verify_config, mock_yaml_load, mock_fi mock_yaml_load.return_value = self.mock_config_data mock_file.side_effect = [mock_open().return_value, FileNotFoundError("File 'nonexistent.yaml' not found")] - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") with self.assertRaises(FileNotFoundError) as context: template.update("nonexistent.yaml") @@ -324,7 +324,7 @@ def test_update_yaml_error(self, mock_verify_config, mock_yaml_load, mock_file): """Test space template update with YAML error""" mock_yaml_load.side_effect = [self.mock_config_data, yaml.YAMLError("Invalid YAML")] - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") with self.assertRaises(ValueError) as context: template.update("invalid.yaml") @@ -342,7 +342,7 @@ def test_update_api_exception(self, mock_handle_exception, mock_verify_config, m mock_custom_api_class.return_value = mock_custom_api mock_custom_api.patch_cluster_custom_object.side_effect = ApiException(status=404) - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") template.update("updated.yaml") mock_handle_exception.assert_called_once() @@ -353,7 +353,7 @@ def test_to_yaml(self, mock_yaml_load, mock_file): """Test converting space template to YAML""" mock_yaml_load.return_value = self.mock_config_data - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") result = template.to_yaml() self.assertIsInstance(result, str) @@ -365,7 +365,7 @@ def test_to_dict(self, mock_yaml_load, mock_file): """Test converting space template to dictionary""" mock_yaml_load.return_value = self.mock_config_data - template = HPSpaceTemplate("test.yaml") + template = HPSpaceTemplate(file_path="test.yaml") result = 
template.to_dict() self.assertEqual(result, self.mock_config_data) From bba82d716dd181507a1aa0418e9e72d9816186dc Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Wed, 12 Nov 2025 11:20:58 -0800 Subject: [PATCH 19/31] Add additional Space parameters for resources including the fractional GPU (#287) --- src/sagemaker/hyperpod/cli/space_utils.py | 76 +++++++++++++++++-- .../hyperpod/space/hyperpod_space.py | 2 +- test/unit_tests/cli/test_space_utils.py | 33 ++++++-- test/unit_tests/test_hyperpod_space.py | 2 +- 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index f9a27a6e..9068ed83 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -42,22 +42,40 @@ def generate_click_command( def decorator(func: Callable) -> Callable: # build resources from CPU/memory options - def _build_resources(cpu, memory, gpu): - if cpu is None and memory is None and gpu is None: + def _build_resources(cpu, cpu_limit, memory, memory_limit, gpu, gpu_limit, + accelerator_partition_type, accelerator_partition_count): + if not any([cpu, cpu_limit, memory, memory_limit, gpu, gpu_limit, + accelerator_partition_type, accelerator_partition_count]): return None + + if (accelerator_partition_type is None) ^ (accelerator_partition_count is None): + raise click.UsageError( + "Both accelerator-partition-type and accelerator-partition-count must be specified together" + ) # Build requests dictionary requests = {} + limits = {} if cpu is not None: requests["cpu"] = cpu + if cpu_limit is not None: + limits["cpu"] = cpu_limit if memory is not None: requests["memory"] = memory + if memory_limit is not None: + limits["memory"] = memory_limit if gpu is not None: requests["nvidia.com/gpu"] = gpu + if gpu_limit is not None: + limits["nvidia.com/gpu"] = gpu_limit + if accelerator_partition_type is not None and accelerator_partition_count is not None: + 
requests[f"nvidia.com/{accelerator_partition_type}"] = accelerator_partition_count + limits[f"nvidia.com/{accelerator_partition_type}"] = accelerator_partition_count # Return ResourceRequirements structure return { - "requests": requests + "requests": requests, + "limits": limits, } def _parse_volume_param(ctx, param, value): @@ -140,7 +158,16 @@ def wrapped_func(*args, **kwargs): if Model is None: raise click.ClickException(f"Unsupported schema version: {version}") - resources = _build_resources(kwargs.pop("cpu", None), kwargs.pop("memory", None), kwargs.pop("gpu", None)) + resources = _build_resources( + kwargs.pop("cpu", None), + kwargs.pop("cpu_limit", None), + kwargs.pop("memory", None), + kwargs.pop("memory_limit", None), + kwargs.pop("gpu", None), + kwargs.pop("gpu_limit", None), + kwargs.pop("accelerator_partition_type", None), + kwargs.pop("accelerator_partition_count", None), + ) if resources is not None: kwargs["resources"] = resources @@ -210,21 +237,56 @@ def wrapped_func(*args, **kwargs): "--cpu", type=str, default=None, - help="CPU resource, e.g. '250m'", + help="CPU resource request, e.g. '250m'", + )(wrapped_func) + + wrapped_func = click.option( + "--cpu-limit", + type=str, + default=None, + help="CPU resource limit, e.g. '250m'", )(wrapped_func) wrapped_func = click.option( "--memory", type=str, default=None, - help="Memory resource, e.g. '256Mi'", + help="Memory resource request, e.g. '256Mi'", + )(wrapped_func) + + wrapped_func = click.option( + "--memory-limit", + type=str, + default=None, + help="Memory resource limit, e.g. '256Mi'", )(wrapped_func) wrapped_func = click.option( "--gpu", type=str, default=None, - help="Gpu resource, e.g. '1'", + help="Gpu resource request, e.g. '1'", + )(wrapped_func) + + wrapped_func = click.option( + "--gpu-limit", + type=str, + default=None, + help="Gpu resource limit, e.g. 
'1'", + )(wrapped_func) + + wrapped_func = click.option( + "--accelerator-partition-type", + type=str, + default=None, + help="Fractional GPU parition type", + )(wrapped_func) + + wrapped_func = click.option( + "--accelerator-partition-count", + type=str, + default=None, + help="Fractional GPU parition count", )(wrapped_func) wrapped_func = click.option( diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index b0fd9f11..c29932c6 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -314,7 +314,7 @@ def list_pods(self) -> List[str]: try: pods = v1.list_namespaced_pod( namespace=self.config.namespace, - label_selector=f"{SPACE_GROUP}/workspaceName={self.config.name}" + label_selector=f"{SPACE_GROUP}/workspace-name={self.config.name}" ) return [pod.metadata.name for pod in pods.items] except Exception as e: diff --git a/test/unit_tests/cli/test_space_utils.py b/test/unit_tests/cli/test_space_utils.py index 389949f5..47f5911c 100644 --- a/test/unit_tests/cli/test_space_utils.py +++ b/test/unit_tests/cli/test_space_utils.py @@ -91,7 +91,7 @@ def cmd(version, domain_config): @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_resources_building(self, mock_load_schema): - """Test CPU and memory resource building""" + """Test CPU, memory, GPU and fractional GPU resource building""" schema = { 'properties': { 'resources': { @@ -117,26 +117,43 @@ class Config: def cmd(version, domain_config): click.echo(json.dumps(domain_config.get('resources'))) - # Test with custom CPU and memory - result = self.runner.invoke(cmd, ['--cpu', '1000m', '--memory', '1Gi']) + # Test with CPU and memory requests and limits + result = self.runner.invoke(cmd, ['--cpu', '1000m', '--cpu-limit', '2000m', '--memory', '1Gi', '--memory-limit', '2Gi']) assert result.exit_code == 0 output = json.loads(result.output) assert output['requests']['cpu'] == '1000m' 
assert output['requests']['memory'] == '1Gi' - assert 'nvidia.com/gpu' not in output['requests'] + assert output['limits']['cpu'] == '2000m' + assert output['limits']['memory'] == '2Gi' - # Test with only CPU - result = self.runner.invoke(cmd, ['--cpu', '750m']) + # Test with GPU requests and limits + result = self.runner.invoke(cmd, ['--gpu', '1', '--gpu-limit', '2']) assert result.exit_code == 0 output = json.loads(result.output) - assert output['requests']['cpu'] == '750m' - assert 'memory' not in output['requests'] + assert output['requests']['nvidia.com/gpu'] == '1' + assert output['limits']['nvidia.com/gpu'] == '2' + + # Test with fractional GPU partitioning + result = self.runner.invoke(cmd, ['--accelerator-partition-type', 'mig-1g.5gb', '--accelerator-partition-count', '2']) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['requests']['nvidia.com/mig-1g.5gb'] == '2' + assert output['limits']['nvidia.com/mig-1g.5gb'] == '2' # Test with no resources specified result = self.runner.invoke(cmd, []) assert result.exit_code == 0 assert result.output.strip() == 'null' + # Test error when only one accelerator partition parameter is provided + result = self.runner.invoke(cmd, ['--accelerator-partition-type', 'mig-1g.5gb']) + assert result.exit_code == 2 + assert 'Both accelerator-partition-type and accelerator-partition-count must be specified together' in result.output + + result = self.runner.invoke(cmd, ['--accelerator-partition-count', '2']) + assert result.exit_code == 2 + assert 'Both accelerator-partition-type and accelerator-partition-count must be specified together' in result.output + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') def test_type_conversion(self, mock_load_schema): """Test type conversion for different parameter types""" diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index 99776583..c06e89c2 100644 --- 
a/test/unit_tests/test_hyperpod_space.py +++ b/test/unit_tests/test_hyperpod_space.py @@ -504,7 +504,7 @@ def test_list_pods_success(self, mock_verify_config, mock_core_api_class): self.assertEqual(result, ["pod1", "pod2"]) mock_core_api.list_namespaced_pod.assert_called_once_with( namespace="test-namespace", - label_selector="workspace.jupyter.org/workspaceName=test-space" + label_selector="workspace.jupyter.org/workspace-name=test-space" ) @patch('sagemaker.hyperpod.space.hyperpod_space.client.CoreV1Api') From 5dffc800b45a0f21a4004ac1bde60021db1dc690 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Mon, 17 Nov 2025 19:19:27 -0800 Subject: [PATCH 20/31] Implement validation for mig profiles for Spaces (#291) * Implement validation for mig profiles when creating/updating spaces * Update Space parameter model * Make Space Template namespaced resource --- .../hyperpod_space_template/v1_0/model.py | 74 +++++++-- .../hyperpod_space_template/v1_0/schema.json | 145 ++++++++++++++++-- .../hyperpod/cli/commands/space_access.py | 6 +- .../hyperpod/cli/constants/space_constants.py | 4 +- src/sagemaker/hyperpod/cli/space_utils.py | 25 +++ .../hyperpod/space/hyperpod_space.py | 86 +++++++++-- .../hyperpod/space/hyperpod_space_template.py | 35 ++++- src/sagemaker/hyperpod/space/utils.py | 36 ++++- test/unit_tests/cli/test_space_utils.py | 103 ++++++++++++- .../test_hyperpod_space_template.py | 65 +++++--- test/unit_tests/test_space_utils.py | 61 +++++++- 11 files changed, 572 insertions(+), 68 deletions(-) diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py index 016c2978..16d51883 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py @@ -43,6 +43,41 @@ class ContainerConfig(BaseModel): ) +class TemplateRef(BaseModel): + """ContainerConfig defines container command and args configuration""" + name: str = 
Field( + description="Name of the WorkspaceTemplate" + ) + namespace: Optional[str] = Field( + default=None, + description="Namespace where the WorkspaceTemplate is located" + ) + + +class IdleDetectionSpec(BaseModel): + """IdleDetectionSpec defines idle detection methods""" + http_get: Optional[Dict[str, Any]] = Field( + default=None, + alias="httpGet", + description="HTTPGet specifies the HTTP request to perform for idle detection" + ) + + +class IdleShutdownSpec(BaseModel): + """IdleShutdownSpec defines idle shutdown configuration""" + enabled: bool = Field( + description="Enabled indicates if idle shutdown is enabled" + ) + timeout_minutes: int = Field( + alias="timeoutMinutes", + description="TimeoutMinutes specifies idle timeout in minutes", + ge=1 + ) + detection: IdleDetectionSpec = Field( + description="Detection specifies how to detect idle state" + ) + + class StorageSpec(BaseModel): """StorageSpec defines the storage configuration for Workspace""" storage_class_name: Optional[str] = Field( @@ -63,11 +98,11 @@ class StorageSpec(BaseModel): class ResourceRequirements(BaseModel): """ResourceRequirements describes the compute resource requirements""" - requests: Optional[Dict[str, str]] = Field( + requests: Optional[Dict[str, Optional[str]]] = Field( default=None, description="Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. Requests cannot exceed Limits." ) - limits: Optional[Dict[str, str]] = Field( + limits: Optional[Dict[str, Optional[str]]] = Field( default=None, description="Limits describes the maximum amount of compute resources allowed." ) @@ -105,7 +140,7 @@ class SpaceConfig(BaseModel): ownership_type: Optional[OwnershipType] = Field( default=None, alias="ownership_type", - description="OwnershipType specifies who can modify the space. 
Public means anyone with RBAC permissions can update/delete the space. OwnerOnly means only the creator can update/delete the space." + description="OwnershipType specifies who can modify the space. 'Public' means anyone with RBAC permissions can update/delete the space. 'OwnerOnly' means only the creator can update/delete the space." ) resources: Optional[ResourceRequirements] = Field( default=None, @@ -127,24 +162,39 @@ class SpaceConfig(BaseModel): node_selector: Optional[Dict[str, str]] = Field( default=None, alias="node_selector", - description="NodeSelector specifies node selection constraints for the space pod (JSON)" + description="NodeSelector specifies node selection constraints for the space pod (JSON string)" ) affinity: Optional[Dict[str, Any]] = Field( default=None, - description="Affinity specifies node affinity and anti-affinity rules for the space pod (JSON)" + description="Affinity specifies node affinity and anti-affinity rules for the space pod (JSON string)" ) tolerations: Optional[List[Dict[str, Any]]] = Field( default=None, - description="Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON)" + description="Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON string)" ) lifecycle: Optional[Dict[str, Any]] = Field( default=None, - description="Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON)" + description="Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON string)" ) - template_ref: Optional[str] = Field( + template_ref: Optional[TemplateRef] = Field( default=None, alias="template_ref", - description="TemplateRef references a WorkspaceTemplate to use as base configuration. When set, template provides defaults and spec fields (Image, Resources, Storage.Size) act as overrides." 
+ description="TemplateRef references a WorkspaceTemplate to use as base configuration. When set, template provides defaults and workspace spec fields act as overrides" + ) + idle_shutdown: Optional[IdleShutdownSpec] = Field( + default=None, + alias="idle_shutdown", + description="IdleShutdown specifies idle shutdown configuration" + ) + app_type: Optional[str] = Field( + default=None, + alias="app_type", + description="AppType specifies the application type for this workspace" + ) + service_account_name: Optional[str] = Field( + default=None, + alias="service_account_name", + description="ServiceAccountName specifies the name of the ServiceAccount to use for the workspace pod" ) @field_validator('volumes') @@ -199,6 +249,12 @@ def to_domain(self) -> Dict: spec["lifecycle"] = self.lifecycle if self.template_ref is not None: spec["templateRef"] = self.template_ref + if self.idle_shutdown is not None: + spec["idleShutdown"] = self.idle_shutdown.model_dump(exclude_none=True, by_alias=True) + if self.app_type is not None: + spec["appType"] = self.app_type + if self.service_account_name is not None: + spec["serviceAccountName"] = self.service_account_name # Create metadata metadata = {"name": self.name} diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json index 30aa045d..71fa032a 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json @@ -47,6 +47,54 @@ "title": "DesiredStatus", "type": "string" }, + "IdleDetectionSpec": { + "description": "IdleDetectionSpec defines idle detection methods", + "properties": { + "httpGet": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "HTTPGet specifies the HTTP request to perform for idle detection", + "title": "Httpget" + } + }, + "title": "IdleDetectionSpec", + 
"type": "object" + }, + "IdleShutdownSpec": { + "description": "IdleShutdownSpec defines idle shutdown configuration", + "properties": { + "enabled": { + "description": "Enabled indicates if idle shutdown is enabled", + "title": "Enabled", + "type": "boolean" + }, + "timeoutMinutes": { + "description": "TimeoutMinutes specifies idle timeout in minutes", + "minimum": 1, + "title": "Timeoutminutes", + "type": "integer" + }, + "detection": { + "$ref": "#/$defs/IdleDetectionSpec", + "description": "Detection specifies how to detect idle state" + } + }, + "required": [ + "enabled", + "timeoutMinutes", + "detection" + ], + "title": "IdleShutdownSpec", + "type": "object" + }, "OwnershipType": { "enum": [ "Public", @@ -62,7 +110,14 @@ "anyOf": [ { "additionalProperties": { - "type": "string" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] }, "type": "object" }, @@ -78,7 +133,14 @@ "anyOf": [ { "additionalProperties": { - "type": "string" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] }, "type": "object" }, @@ -140,6 +202,34 @@ "title": "StorageSpec", "type": "object" }, + "TemplateRef": { + "description": "ContainerConfig defines container command and args configuration", + "properties": { + "name": { + "description": "Name of the WorkspaceTemplate", + "title": "Name", + "type": "string" + }, + "namespace": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Namespace where the WorkspaceTemplate is located", + "title": "Namespace" + } + }, + "required": [ + "name" + ], + "title": "TemplateRef", + "type": "object" + }, "VolumeSpec": { "description": "VolumeSpec defines a volume to mount from an existing PVC", "properties": { @@ -230,7 +320,7 @@ } ], "default": null, - "description": "OwnershipType specifies who can modify the space. Public means anyone with RBAC permissions can update/delete the space. OwnerOnly means only the creator can update/delete the space." 
+ "description": "OwnershipType specifies who can modify the space. 'Public' means anyone with RBAC permissions can update/delete the space. 'OwnerOnly' means only the creator can update/delete the space." }, "resources": { "anyOf": [ @@ -297,7 +387,7 @@ } ], "default": null, - "description": "NodeSelector specifies node selection constraints for the space pod (JSON)", + "description": "NodeSelector specifies node selection constraints for the space pod (JSON string)", "title": "Node Selector" }, "affinity": { @@ -311,7 +401,7 @@ } ], "default": null, - "description": "Affinity specifies node affinity and anti-affinity rules for the space pod (JSON)", + "description": "Affinity specifies node affinity and anti-affinity rules for the space pod (JSON string)", "title": "Affinity" }, "tolerations": { @@ -328,7 +418,7 @@ } ], "default": null, - "description": "Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON)", + "description": "Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON string)", "title": "Tolerations" }, "lifecycle": { @@ -342,10 +432,47 @@ } ], "default": null, - "description": "Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON)", + "description": "Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON string)", "title": "Lifecycle" }, "template_ref": { + "anyOf": [ + { + "$ref": "#/$defs/TemplateRef" + }, + { + "type": "null" + } + ], + "default": null, + "description": "TemplateRef references a WorkspaceTemplate to use as base configuration. 
When set, template provides defaults and workspace spec fields act as overrides" + }, + "idle_shutdown": { + "anyOf": [ + { + "$ref": "#/$defs/IdleShutdownSpec" + }, + { + "type": "null" + } + ], + "default": null, + "description": "IdleShutdown specifies idle shutdown configuration" + }, + "app_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "AppType specifies the application type for this workspace", + "title": "App Type" + }, + "service_account_name": { "anyOf": [ { "type": "string" @@ -355,8 +482,8 @@ } ], "default": null, - "description": "TemplateRef references a WorkspaceTemplate to use as base configuration. When set, template provides defaults and spec fields (Image, Resources, Storage.Size) act as overrides.", - "title": "Template Ref" + "description": "ServiceAccountName specifies the name of the ServiceAccount to use for the workspace pod", + "title": "Service Account Name" } }, "required": [ diff --git a/src/sagemaker/hyperpod/cli/commands/space_access.py b/src/sagemaker/hyperpod/cli/commands/space_access.py index fbe36e63..f35994bc 100644 --- a/src/sagemaker/hyperpod/cli/commands/space_access.py +++ b/src/sagemaker/hyperpod/cli/commands/space_access.py @@ -9,7 +9,11 @@ @click.command("hyp-space-access") @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") -@click.option("--connection-type", "-t", required=False, default="vscode-remote", help="Remote access type") +@click.option("--connection-type", "-t", + required=False, + default="vscode-remote", + help="Remote access type supported values: [vscode-remote, web-ui] [default: vscode-remote]" +) def space_access_create(name, namespace, connection_type): """Create a space access resource.""" diff --git a/src/sagemaker/hyperpod/cli/constants/space_constants.py b/src/sagemaker/hyperpod/cli/constants/space_constants.py index 
0c4d4453..b595a7aa 100644 --- a/src/sagemaker/hyperpod/cli/constants/space_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_constants.py @@ -16,5 +16,5 @@ # Immutable fields that cannot be updated after space creation IMMUTABLE_FIELDS = { "storage", # storage is immutable per Go struct validation - "template_ref", # templateRef is immutable per Go struct validation -} \ No newline at end of file +} +ENABLE_MIG_PROFILE_VALIDATION = False diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index 9068ed83..c8193098 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -69,6 +69,8 @@ def _build_resources(cpu, cpu_limit, memory, memory_limit, gpu, gpu_limit, if gpu_limit is not None: limits["nvidia.com/gpu"] = gpu_limit if accelerator_partition_type is not None and accelerator_partition_count is not None: + if not accelerator_partition_type.startswith("mig"): + raise click.UsageError(f"Invalid accelerator partition type '{accelerator_partition_type}'") requests[f"nvidia.com/{accelerator_partition_type}"] = accelerator_partition_count limits[f"nvidia.com/{accelerator_partition_type}"] = accelerator_partition_count @@ -149,6 +151,22 @@ def _parse_container_config_param(ctx, param, value): return parts except Exception as e: raise click.UsageError(f"Error parsing container-config: {str(e)}") + + def _parse_template_ref(ctx, param, value): + """Parse template ref from command line format to dictionary format.""" + if not value: + return None + + try: + parts = {} + for item in value.split(','): + if '=' not in item: + raise click.UsageError(f"Invalid template ref format: '{item}' should be key=value") + key, val = item.split('=', 1) + parts[key.strip()] = val.strip() + return parts + except Exception as e: + raise click.UsageError(f"Error parsing template ref: {str(e)}") # 1) the wrapper click will call def wrapped_func(*args, **kwargs): @@ -310,6 +328,12 @@ def 
wrapped_func(*args, **kwargs): help="Container configuration. Format: --container-config command=,args=", )(wrapped_func) + wrapped_func = click.option( + "--template-ref", + callback=_parse_template_ref, + help="TemplateRef references a WorkspaceTemplate to use as base configuration. Format: --template-ref name=,namespace=", + )(wrapped_func) + # Exclude the props that were handled out of the below for loop excluded_props = set( [ @@ -318,6 +342,7 @@ def wrapped_func(*args, **kwargs): "volumes", "storage", "container_config", + "template_ref", ] ) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index c29932c6..efc098db 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -2,7 +2,7 @@ import yaml import boto3 from typing import List, Optional, ClassVar, Dict, Any -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -11,9 +11,15 @@ handle_exception, get_default_namespace, setup_logging, - verify_kubernetes_version_compatibility + verify_kubernetes_version_compatibility, + get_current_cluster, + get_current_region, + get_cluster_instance_types, +) +from sagemaker.hyperpod.space.utils import ( + map_kubernetes_response_to_model, + get_pod_instance_type, ) -from sagemaker.hyperpod.space.utils import map_kubernetes_response_to_model from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( _hyperpod_telemetry_emitter, ) @@ -22,6 +28,7 @@ SPACE_GROUP, SPACE_VERSION, SPACE_PLURAL, + ENABLE_MIG_PROFILE_VALIDATION, ) from sagemaker.hyperpod.cli.constants.space_access_constants import ( SPACE_ACCESS_GROUP, @@ -30,6 +37,9 @@ ) from hyperpod_space_template.v1_0.model import SpaceConfig +if ENABLE_MIG_PROFILE_VALIDATION: + from sagemaker.hyperpod.training.hyperpod_pytorch_job import 
list_accelerator_partition_types + class HPSpace(BaseModel): """HyperPod Space on Amazon SageMaker HyperPod clusters. @@ -42,7 +52,7 @@ class HPSpace(BaseModel): model_config = ConfigDict(extra="forbid") config: SpaceConfig = Field( - description="The space configuration using the template model" + description="The space configuration using the space parameter model" ) raw_resource: Optional[Dict[str, Any]] = Field( @@ -96,11 +106,33 @@ def create(self, debug: bool = False): Raises: Exception: If the space creation fails or Kubernetes API call fails """ + self.verify_kube_config() logger = self.get_logger() logger = setup_logging(logger, debug) + # Validate supported MIG profiles for the cluster + if ENABLE_MIG_PROFILE_VALIDATION: + if self.config.resources: + mig_profiles = set() + if self.config.resources.requests: + mig_profiles.update([key for key in self.config.resources.requests.keys() if key.startswith("nvidia.com/mig")]) + if self.config.resources.limits: + mig_profiles.update([key for key in self.config.resources.limits.keys() if key.startswith("nvidia.com/mig")]) + + if len(mig_profiles) > 1: + raise RuntimeError("Space only supports one MIG profile") + + if mig_profiles: + cluster_instance_types = get_cluster_instance_types( + get_current_cluster(), + get_current_region() + ) + supported_mig_profiles = {profile for instance_type in cluster_instance_types for profile in list_accelerator_partition_types(instance_type)} + if list(mig_profiles)[0] not in supported_mig_profiles: + raise RuntimeError(f"Accelerator partition type '{list(mig_profiles)[0]}' does not exist in this cluster. 
Use 'hyp list-accelerator-partition-type' to check for available resources.") + # Convert config to domain model domain_config = self.config.to_domain() config_body = domain_config["space_spec"] @@ -189,12 +221,13 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_space") - def get(cls, name: str, namespace: str = "default") -> "HPSpace": + def get(cls, name: str, namespace: str = None) -> "HPSpace": """Get a specific HyperPod Space by name. Args: name (str): The name of the space to retrieve - namespace (str, optional): The Kubernetes namespace. Defaults to "default". + namespace (str, optional): The Kubernetes namespace. + If None, uses the default namespace from current context. Returns: HPSpace: The space instance @@ -204,6 +237,9 @@ def get(cls, name: str, namespace: str = "default") -> "HPSpace": """ cls.verify_kube_config() + if not namespace: + namespace = get_default_namespace() + custom_api = client.CustomObjectsApi() try: @@ -265,6 +301,33 @@ def update(self, **kwargs): self.verify_kube_config() logger = self.get_logger() + # Validate supported MIG profile for node which the Space is running on + if ENABLE_MIG_PROFILE_VALIDATION: + if "resources" in kwargs: + mig_profiles = set() + mig_profiles.update([key for key in kwargs["resources"].get("requests", {}).keys() if key.startswith("nvidia.com/mig")]) + mig_profiles.update([key for key in kwargs["resources"].get("limits", {}).keys() if key.startswith("nvidia.com/mig")]) + + if len(mig_profiles) > 1: + raise RuntimeError("Space only supports one MIG profile") + + if mig_profiles: + pods = self.list_pods() + if not pods: + raise RuntimeError(f"No pods found for space '{self.config.name}'") + + node_instance_type = get_pod_instance_type(pods[0], self.config.namespace) + supported_mig_profiles = set(list_accelerator_partition_types(node_instance_type)) + if list(mig_profiles)[0] not in supported_mig_profiles: + raise 
RuntimeError(f"Accelerator partition type '{list(mig_profiles)[0]}' does not exist in this cluster. Use 'hyp list-accelerator-partition-type' to check for available resources.") + + # Ensure existing MIG profile gets removed before setting a new one + existing_config = HPSpace.get(self.config.name, self.config.namespace).config + existing_mig_profiles = [key for key in existing_config.resources.requests.keys() if key.startswith("nvidia.com/mig")] + if existing_mig_profiles: + kwargs["resources"]["requests"].update({existing_mig_profiles[0]: None}) + kwargs["resources"]["limits"].update({existing_mig_profiles[0]: None}) + custom_api = client.CustomObjectsApi() # Update space config with the input config @@ -308,9 +371,9 @@ def list_pods(self) -> List[str]: """ self.verify_kube_config() logger = self.get_logger() - + v1 = client.CoreV1Api() - + try: pods = v1.list_namespaced_pod( namespace=self.config.namespace, @@ -363,7 +426,7 @@ def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[st """Create a space access for this space. Args: - connection_type (str, optional): The IDE type for remote access. Defaults to "vscode". + connection_type (str, optional): The IDE type for remote access. Defaults to "vscode-remote". 
Returns: Dict[str, str]: Dictionary with 'SpaceConnectionType' and 'SpaceConnectionUrl' keys @@ -374,6 +437,9 @@ def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[st self.verify_kube_config() logger = self.get_logger() + if connection_type not in {"vscode-remote", "web-ui"}: + raise ValueError("--connection-type must be 'vscode-remote' or 'web-ui'.") + config = { "metadata": { "namespace": self.config.namespace, @@ -401,4 +467,4 @@ def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[st } except Exception as e: logger.error(f"Failed to create space access for {self.config.name}!") - handle_exception(e, self.config.name, self.config.namespace) + handle_exception(e, self.config.name, self.config.namespace) \ No newline at end of file diff --git a/src/sagemaker/hyperpod/space/hyperpod_space_template.py b/src/sagemaker/hyperpod/space/hyperpod_space_template.py index 80ae8800..0c596372 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space_template.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space_template.py @@ -6,6 +6,7 @@ from sagemaker.hyperpod.common.utils import ( handle_exception, + get_default_namespace, verify_kubernetes_version_compatibility ) from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( @@ -55,6 +56,7 @@ def __init__(self, *, file_path: Optional[str] = None, config_data: Optional[Dic self.config_data = config_data self.name = self.config_data.get('metadata', {}).get('name') + self.namespace = self.config_data.get('metadata', {}).get('namespace') @classmethod def get_logger(cls): @@ -80,9 +82,10 @@ def create(self) -> "HPSpaceTemplate": try: api_instance = client.CustomObjectsApi() - response = api_instance.create_cluster_custom_object( + response = api_instance.create_namespaced_custom_object( group=SPACE_TEMPLATE_GROUP, version=SPACE_TEMPLATE_VERSION, + namespace=self.namespace, plural=SPACE_TEMPLATE_PLURAL, body=self.config_data ) @@ -98,19 +101,27 @@ def create(self) -> 
"HPSpaceTemplate": @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_space_templates") - def list(cls) -> List["HPSpaceTemplate"]: - """List all space templates. + def list(cls, namespace: Optional[str] = None) -> List["HPSpaceTemplate"]: + """List all space templates in the specified namespace. + + Args: + namespace (str, optional): The Kubernetes namespace to list space templates from. + If None, uses the default namespace from current context. Returns: List of HPSpaceTemplate instances """ cls.verify_kube_config() + + if not namespace: + namespace = get_default_namespace() try: api_instance = client.CustomObjectsApi() - response = api_instance.list_cluster_custom_object( + response = api_instance.list_namespaced_custom_object( group=SPACE_TEMPLATE_GROUP, version=SPACE_TEMPLATE_VERSION, + namespace=namespace, plural=SPACE_TEMPLATE_PLURAL ) @@ -128,22 +139,28 @@ def list(cls) -> List["HPSpaceTemplate"]: @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_space_template") - def get(cls, name: str) -> "HPSpaceTemplate": + def get(cls, name: str, namespace: Optional[str] = None) -> "HPSpaceTemplate": """Get a specific space template by name. Args: name: Name of the space template + namespace (str, optional): The Kubernetes namespace. + If None, uses the default namespace from current context. 
Returns: HPSpaceTemplate instance """ cls.verify_kube_config() + + if not namespace: + namespace = get_default_namespace() try: api_instance = client.CustomObjectsApi() - response = api_instance.get_cluster_custom_object( + response = api_instance.get_namespaced_custom_object( group=SPACE_TEMPLATE_GROUP, version=SPACE_TEMPLATE_VERSION, + namespace=namespace, plural=SPACE_TEMPLATE_PLURAL, name=name ) @@ -167,9 +184,10 @@ def delete(self) -> None: try: api_instance = client.CustomObjectsApi() - api_instance.delete_cluster_custom_object( + api_instance.delete_namespaced_custom_object( group=SPACE_TEMPLATE_GROUP, version=SPACE_TEMPLATE_VERSION, + namespace=self.namespace, plural=SPACE_TEMPLATE_PLURAL, name=self.name ) @@ -209,9 +227,10 @@ def update(self, file_path: str) -> "HPSpaceTemplate": config_data['metadata'].pop(field, None) api_instance = client.CustomObjectsApi() - response = api_instance.patch_cluster_custom_object( + response = api_instance.patch_namespaced_custom_object( group=SPACE_TEMPLATE_GROUP, version=SPACE_TEMPLATE_VERSION, + namespace=self.namespace, plural=SPACE_TEMPLATE_PLURAL, name=self.name, body=config_data diff --git a/src/sagemaker/hyperpod/space/utils.py b/src/sagemaker/hyperpod/space/utils.py index 4b3023f6..da8200a1 100644 --- a/src/sagemaker/hyperpod/space/utils.py +++ b/src/sagemaker/hyperpod/space/utils.py @@ -1,8 +1,9 @@ """Utility functions for space operations.""" import re -from typing import Dict, Any, Set +from typing import Dict, Any, Set, List from pydantic import BaseModel +from kubernetes import client def camel_to_snake(name: str) -> str: @@ -55,3 +56,36 @@ def map_kubernetes_response_to_model(k8s_data: Dict[str, Any], model_class: Base mapped_data[snake_field] = value return mapped_data + + +def get_pod_instance_type(pod_name: str, namespace: str = "default") -> str: + """ + Get the instance type of the node where a pod is running. 
+ + Args: + pod_name: Name of the pod + namespace: Kubernetes namespace of the pod + + Returns: + Instance type of the node running the pod + + Raises: + RuntimeError: If pod is not found or not scheduled on a node + """ + v1 = client.CoreV1Api() + + pod = v1.read_namespaced_pod(name=pod_name, namespace=namespace) + + if not pod.spec.node_name: + raise RuntimeError(f"Pod '{pod_name}' is not scheduled on any node") + + node = v1.read_node(name=pod.spec.node_name) + if node.metadata.labels: + instance_type = ( + node.metadata.labels.get('node.kubernetes.io/instance-type') or + node.metadata.labels.get('beta.kubernetes.io/instance-type') + ) + if instance_type: + return instance_type + + raise RuntimeError(f"Instance type not found for node '{pod.spec.node_name}'") diff --git a/test/unit_tests/cli/test_space_utils.py b/test/unit_tests/cli/test_space_utils.py index 47f5911c..9c658c0e 100644 --- a/test/unit_tests/cli/test_space_utils.py +++ b/test/unit_tests/cli/test_space_utils.py @@ -239,7 +239,6 @@ def test_immutable_fields_excluded_in_update(self, mock_load_schema): 'properties': { 'name': {'type': 'string'}, 'storage': {'type': 'object'}, # storage is immutable - 'template_ref': {'type': 'string'}, # template_ref is immutable 'image': {'type': 'string'} }, 'required': ['name'] @@ -266,7 +265,6 @@ def cmd(version, domain_config): assert result.exit_code == 0 # storage and template_ref should not be available in update mode assert '--storage' not in result.output - assert '--template-ref' not in result.output # but other fields should be available assert '--name' in result.output assert '--image' in result.output @@ -617,3 +615,104 @@ def cmd(version, domain_config): result = self.runner.invoke(cmd, ['--name', 'test-space']) assert result.exit_code == 0 assert 'success' in result.output + + @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_template_ref_parsing(self, mock_load_schema): + """Test template_ref parameter parsing""" + schema 
= { + 'properties': { + 'name': {'type': 'string'}, + 'template_ref': {'type': 'object'} + }, + 'required': ['name'] + } + mock_load_schema.return_value = schema + + class DummyModel(BaseModel): + class Config: + extra = 'allow' + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(domain_config.get('template_ref'))) + + # Test valid template_ref parsing + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--template-ref', 'name=sagemaker-jupyter-template,namespace=jupyter-k8s-shared' + ]) + assert result.exit_code == 0 + template_ref = json.loads(result.output) + assert template_ref['name'] == 'sagemaker-jupyter-template' + assert template_ref['namespace'] == 'jupyter-k8s-shared' + + # Test template_ref with different values + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--template-ref', 'name=custom-template,namespace=default' + ]) + assert result.exit_code == 0 + template_ref = json.loads(result.output) + assert template_ref['name'] == 'custom-template' + assert template_ref['namespace'] == 'default' + + # Test invalid template_ref format (missing equals) + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--template-ref', 'invalid_format' + ]) + assert result.exit_code == 2 + assert 'Invalid template ref format' in result.output + + # Test invalid template_ref format (no comma separation) + result = self.runner.invoke(cmd, [ + '--name', 'test-space', + '--template-ref', 'name=template' + ]) + assert result.exit_code == 0 + template_ref = json.loads(result.output) + assert template_ref['name'] == 'template' + assert 'namespace' not in template_ref + + # Test empty template_ref + result = self.runner.invoke(cmd, ['--name', 'test-space']) + assert result.exit_code == 0 + assert result.output.strip() == 'null' + + 
@patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') + def test_accelerator_partition_validation(self, mock_load_schema): + """Test accelerator partition type validation""" + schema = {'properties': {}, 'required': []} + mock_load_schema.return_value = schema + + class DummyModel(BaseModel): + class Config: + extra = 'allow' + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command(registry=registry, schema_pkg="hyperpod_space_template") + def cmd(version, domain_config): + click.echo(json.dumps(domain_config.get('resources'))) + + # Test invalid accelerator partition type (not starting with 'mig') + result = self.runner.invoke(cmd, [ + '--accelerator-partition-type', 'invalid-type', + '--accelerator-partition-count', '2' + ]) + assert result.exit_code == 2 + assert "Invalid accelerator partition type 'invalid-type'" in result.output + + # Test valid accelerator partition type + result = self.runner.invoke(cmd, [ + '--accelerator-partition-type', 'mig-2g.10gb', + '--accelerator-partition-count', '1' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['requests']['nvidia.com/mig-2g.10gb'] == '1' + assert output['limits']['nvidia.com/mig-2g.10gb'] == '1' diff --git a/test/unit_tests/test_hyperpod_space_template.py b/test/unit_tests/test_hyperpod_space_template.py index 8e918bb9..755b24b3 100644 --- a/test/unit_tests/test_hyperpod_space_template.py +++ b/test/unit_tests/test_hyperpod_space_template.py @@ -15,7 +15,8 @@ def setUp(self): "apiVersion": "workspace.jupyter.org/v1alpha1", "kind": "WorkspaceTemplate", "metadata": { - "name": "test-template" + "name": "test-template", + "namespace": "test-namespace" }, "spec": { "displayName": "Test Template", @@ -80,15 +81,16 @@ def test_create_success(self, mock_verify_config, mock_custom_api_class, mock_ya mock_yaml_load.return_value = self.mock_config_data mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - 
mock_custom_api.create_cluster_custom_object.return_value = self.mock_config_data + mock_custom_api.create_namespaced_custom_object.return_value = self.mock_config_data template = HPSpaceTemplate(file_path="test.yaml") template.create() mock_verify_config.assert_called_once() - mock_custom_api.create_cluster_custom_object.assert_called_once_with( + mock_custom_api.create_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", + namespace="test-namespace", plural="workspacetemplates", body=self.mock_config_data ) @@ -103,7 +105,7 @@ def test_create_api_exception(self, mock_handle_exception, mock_verify_config, m mock_yaml_load.return_value = self.mock_config_data mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.create_cluster_custom_object.side_effect = ApiException(status=409) + mock_custom_api.create_namespaced_custom_object.side_effect = ApiException(status=409) template = HPSpaceTemplate(file_path="test.yaml") template.create() @@ -119,7 +121,7 @@ def test_create_general_exception(self, mock_verify_config, mock_custom_api_clas mock_yaml_load.return_value = self.mock_config_data mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.create_cluster_custom_object.side_effect = Exception("Creation failed") + mock_custom_api.create_namespaced_custom_object.side_effect = Exception("Creation failed") template = HPSpaceTemplate(file_path="test.yaml") @@ -128,44 +130,49 @@ def test_create_general_exception(self, mock_verify_config, mock_custom_api_clas @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') @patch.object(HPSpaceTemplate, 'verify_kube_config') - def test_list_success(self, mock_verify_config, mock_custom_api_class): + @patch('sagemaker.hyperpod.space.hyperpod_space_template.get_default_namespace') + def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class): """Test 
successful space template listing""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api + mock_get_namespace.return_value = "default" mock_response = { "items": [ { - "metadata": {"name": "template1"}, + "metadata": {"name": "template1", "namespace": "default"}, "spec": {"displayName": "Template 1"} }, { - "metadata": {"name": "template2"}, + "metadata": {"name": "template2", "namespace": "default"}, "spec": {"displayName": "Template 2"} } ] } - mock_custom_api.list_cluster_custom_object.return_value = mock_response + mock_custom_api.list_namespaced_custom_object.return_value = mock_response with patch('builtins.open', new_callable=mock_open), \ patch('yaml.safe_load', return_value=mock_response["items"][0]): result = HPSpaceTemplate.list() self.assertEqual(len(result), 2) - mock_custom_api.list_cluster_custom_object.assert_called_once_with( + mock_custom_api.list_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", + namespace="default", plural="workspacetemplates" ) @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') - @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.HPSpaceTemplate.verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') - def test_list_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + @patch('sagemaker.hyperpod.space.hyperpod_space_template.get_default_namespace') + def test_list_api_exception(self, mock_get_namespace, mock_handle_exception, mock_verify_config, mock_custom_api_class): """Test space template listing with API exception""" + mock_get_namespace.return_value = "default" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.list_cluster_custom_object.side_effect = ApiException(status=500) + mock_custom_api.list_namespaced_custom_object.side_effect = 
ApiException(status=500) HPSpaceTemplate.list() @@ -177,21 +184,24 @@ def test_list_general_exception(self, mock_verify_config, mock_custom_api_class) """Test space template listing with general exception""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.list_cluster_custom_object.side_effect = Exception("List failed") + mock_custom_api.list_namespaced_custom_object.side_effect = Exception("List failed") with self.assertRaises(Exception): HPSpaceTemplate.list() @patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') @patch.object(HPSpaceTemplate, 'verify_kube_config') - def test_get_success(self, mock_verify_config, mock_custom_api_class): + @patch('sagemaker.hyperpod.space.hyperpod_space_template.get_default_namespace') + def test_get_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class): """Test successful space template retrieval""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api + mock_get_namespace.return_value = "default" mock_response = { "metadata": { "name": "test-template", + "namespace": "test-namespace", "managedFields": [{"manager": "test"}] }, "spec": {"displayName": "Test Template"} @@ -200,27 +210,30 @@ def test_get_success(self, mock_verify_config, mock_custom_api_class): "metadata": {"name": "test-template"}, "spec": {"displayName": "Test Template"} } - mock_custom_api.get_cluster_custom_object.return_value = mock_response + mock_custom_api.get_namespaced_custom_object.return_value = mock_response with patch('builtins.open', new_callable=mock_open), \ patch('yaml.safe_load', return_value=expected_response): result = HPSpaceTemplate.get("test-template") - mock_custom_api.get_cluster_custom_object.assert_called_once_with( + mock_custom_api.get_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", + namespace="default", plural="workspacetemplates", name="test-template" ) 
@patch('sagemaker.hyperpod.space.hyperpod_space_template.client.CustomObjectsApi') - @patch.object(HPSpaceTemplate, 'verify_kube_config') + @patch('sagemaker.hyperpod.space.hyperpod_space_template.HPSpaceTemplate.verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space_template.handle_exception') - def test_get_api_exception(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): + @patch('sagemaker.hyperpod.space.hyperpod_space_template.get_default_namespace') + def test_get_api_exception(self, mock_get_namespace, mock_handle_exception, mock_verify_config, mock_custom_api_class): """Test space template retrieval with API exception""" + mock_get_namespace.return_value = "default" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.get_cluster_custom_object.side_effect = ApiException(status=404) + mock_custom_api.get_namespaced_custom_object.side_effect = ApiException(status=404) HPSpaceTemplate.get("nonexistent-template") @@ -240,9 +253,10 @@ def test_delete_success(self, mock_verify_config, mock_custom_api_class, mock_ya template.delete() mock_verify_config.assert_called_once() - mock_custom_api.delete_cluster_custom_object.assert_called_once_with( + mock_custom_api.delete_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", + namespace="test-namespace", plural="workspacetemplates", name="test-template" ) @@ -257,7 +271,7 @@ def test_delete_api_exception(self, mock_handle_exception, mock_verify_config, m mock_yaml_load.return_value = self.mock_config_data mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.delete_cluster_custom_object.side_effect = ApiException(status=404) + mock_custom_api.delete_namespaced_custom_object.side_effect = ApiException(status=404) template = HPSpaceTemplate(file_path="test.yaml") template.delete() @@ -273,15 +287,16 @@ def test_update_success(self, mock_verify_config, 
mock_custom_api_class, mock_ya mock_yaml_load.side_effect = [self.mock_config_data, self.mock_config_data] mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.patch_cluster_custom_object.return_value = self.mock_config_data + mock_custom_api.patch_namespaced_custom_object.return_value = self.mock_config_data template = HPSpaceTemplate(file_path="test.yaml") template.update("updated.yaml") mock_verify_config.assert_called_once() - mock_custom_api.patch_cluster_custom_object.assert_called_once_with( + mock_custom_api.patch_namespaced_custom_object.assert_called_once_with( group="workspace.jupyter.org", version="v1alpha1", + namespace="test-namespace", plural="workspacetemplates", name="test-template", body=self.mock_config_data @@ -340,7 +355,7 @@ def test_update_api_exception(self, mock_handle_exception, mock_verify_config, m mock_yaml_load.side_effect = [self.mock_config_data, self.mock_config_data] mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api - mock_custom_api.patch_cluster_custom_object.side_effect = ApiException(status=404) + mock_custom_api.patch_namespaced_custom_object.side_effect = ApiException(status=404) template = HPSpaceTemplate(file_path="test.yaml") template.update("updated.yaml") diff --git a/test/unit_tests/test_space_utils.py b/test/unit_tests/test_space_utils.py index 025c0ee1..a0e6a3ef 100644 --- a/test/unit_tests/test_space_utils.py +++ b/test/unit_tests/test_space_utils.py @@ -1,7 +1,9 @@ """Unit tests for space utils module.""" import unittest -from sagemaker.hyperpod.space.utils import camel_to_snake, get_model_fields, map_kubernetes_response_to_model +from unittest.mock import Mock, patch +from kubernetes import client +from sagemaker.hyperpod.space.utils import camel_to_snake, get_model_fields, map_kubernetes_response_to_model, get_pod_instance_type from hyperpod_space_template.v1_0.model import SpaceConfig @@ -74,3 +76,60 @@ def 
test_map_kubernetes_response_creates_valid_config(self): self.assertEqual(config.display_name, 'Valid Space') self.assertEqual(config.namespace, 'test') self.assertEqual(config.image, 'valid:latest') + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_get_pod_instance_type_success(self, mock_core_v1): + """Test successful retrieval of pod instance type.""" + # Mock pod with node assignment + mock_pod = Mock() + mock_pod.spec.node_name = 'test-node' + + # Mock node with instance type label + mock_node = Mock() + mock_node.metadata.labels = {'node.kubernetes.io/instance-type': 'ml.p4d.24xlarge'} + + # Setup API mock + mock_api = Mock() + mock_api.read_namespaced_pod.return_value = mock_pod + mock_api.read_node.return_value = mock_node + mock_core_v1.return_value = mock_api + + result = get_pod_instance_type('test-pod', 'default') + + self.assertEqual(result, 'ml.p4d.24xlarge') + mock_api.read_namespaced_pod.assert_called_once_with(name='test-pod', namespace='default') + mock_api.read_node.assert_called_once_with(name='test-node') + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_get_pod_instance_type_pod_not_scheduled(self, mock_core_v1): + """Test error when pod is not scheduled on any node.""" + mock_pod = Mock() + mock_pod.spec.node_name = None + + mock_api = Mock() + mock_api.read_namespaced_pod.return_value = mock_pod + mock_core_v1.return_value = mock_api + + with self.assertRaises(RuntimeError) as context: + get_pod_instance_type('unscheduled-pod') + + self.assertIn("Pod 'unscheduled-pod' is not scheduled", str(context.exception)) + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_get_pod_instance_type_no_instance_type_label(self, mock_core_v1): + """Test error when node has no instance type label.""" + mock_pod = Mock() + mock_pod.spec.node_name = 'test-node' + + mock_node = Mock() + mock_node.metadata.labels = {'other.label': 'value'} + + mock_api = Mock() + 
mock_api.read_namespaced_pod.return_value = mock_pod + mock_api.read_node.return_value = mock_node + mock_core_v1.return_value = mock_api + + with self.assertRaises(RuntimeError) as context: + get_pod_instance_type('test-pod') + + self.assertIn("Instance type not found for node 'test-node'", str(context.exception)) From 0afcec1998d7278e71ea8a1bbd331df472499a7d Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Wed, 19 Nov 2025 18:45:31 -0800 Subject: [PATCH 21/31] Parker GA issues (#296) * Update Space Template CLI to be namespaced * Space get-logs default to the workspace container * Remove error handling to bubble up the actual K8s errors * Listing public Spaces * Fix typos, elaborated text, add logic to parse idle-shutdown --- .../hyperpod_space_template/v1_0/model.py | 6 +- .../hyperpod_space_template/v1_0/schema.json | 8 +- src/sagemaker/hyperpod/cli/commands/space.py | 151 +++++++---------- .../hyperpod/cli/commands/space_access.py | 10 +- .../hyperpod/cli/commands/space_template.py | 90 +++++----- src/sagemaker/hyperpod/cli/space_utils.py | 62 ++++++- .../hyperpod/space/hyperpod_space.py | 26 ++- test/unit_tests/cli/test_space.py | 155 +----------------- test/unit_tests/cli/test_space_access.py | 14 -- test/unit_tests/cli/test_space_template.py | 80 +++------ test/unit_tests/test_hyperpod_space.py | 32 ++-- 11 files changed, 220 insertions(+), 414 deletions(-) diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py index 16d51883..58084e11 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py @@ -68,9 +68,9 @@ class IdleShutdownSpec(BaseModel): enabled: bool = Field( description="Enabled indicates if idle shutdown is enabled" ) - timeout_minutes: int = Field( - alias="timeoutMinutes", - description="TimeoutMinutes specifies idle timeout in minutes", + idle_timeout_in_minutes: int = Field( + 
alias="idleTimeoutInMinutes", + description="IdleTimeoutInMinutes specifies idle timeout in minutes", ge=1 ) detection: IdleDetectionSpec = Field( diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json index 71fa032a..b94817f5 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json @@ -76,10 +76,10 @@ "title": "Enabled", "type": "boolean" }, - "timeoutMinutes": { - "description": "TimeoutMinutes specifies idle timeout in minutes", + "idleTimeoutInMinutes": { + "description": "IdleTimeoutInMinutes specifies idle timeout in minutes", "minimum": 1, - "title": "Timeoutminutes", + "title": "Idletimeoutinminutes", "type": "integer" }, "detection": { @@ -89,7 +89,7 @@ }, "required": [ "enabled", - "timeoutMinutes", + "idleTimeoutInMinutes", "detection" ], "title": "IdleShutdownSpec", diff --git a/src/sagemaker/hyperpod/cli/commands/space.py b/src/sagemaker/hyperpod/cli/commands/space.py index 28c46b71..75261078 100644 --- a/src/sagemaker/hyperpod/cli/commands/space.py +++ b/src/sagemaker/hyperpod/cli/commands/space.py @@ -19,14 +19,10 @@ ) def space_create(version, config): """Create a space resource.""" - try: - space_config = SpaceConfig(**config) - space = HPSpace(config=space_config) - space.create() - - click.echo(f"Space '{space_config.name}' created successfully in namespace '{space_config.namespace}'") - except Exception as e: - click.echo(f"Error creating space: {e}", err=True) + space_config = SpaceConfig(**config) + space = HPSpace(config=space_config) + space.create() + click.echo(f"Space '{space_config.name}' created successfully in namespace '{space_config.namespace}'") @click.command("hyp-space") @@ -34,42 +30,39 @@ def space_create(version, config): @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") def space_list(namespace, output): """List 
space resources.""" - try: - spaces = HPSpace.list(namespace=namespace) - - if output == "json": - spaces_data = [] + spaces = HPSpace.list(namespace=namespace) + + if output == "json": + spaces_data = [] + for space in spaces: + space_dict = space.config.model_dump() + spaces_data.append(space_dict) + click.echo(json.dumps(spaces_data, indent=2)) + else: + if spaces: + table_data = [] for space in spaces: - space_dict = space.config.model_dump() - spaces_data.append(space_dict) - click.echo(json.dumps(spaces_data, indent=2)) + # Extract status conditions from raw resource + available = "" + progressing = "" + degraded = "" + + if space.status and 'conditions' in space.status: + conditions = {c['type']: c['status'] for c in space.status['conditions']} + available = conditions.get('Available', '') + progressing = conditions.get('Progressing', '') + degraded = conditions.get('Degraded', '') + + table_data.append([ + space.config.name, + namespace, + available, + progressing, + degraded + ]) + click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "AVAILABLE", "PROGRESSING", "DEGRADED"])) else: - if spaces: - table_data = [] - for space in spaces: - # Extract status conditions from raw resource - available = "" - progressing = "" - degraded = "" - - if space.status and 'conditions' in space.status: - conditions = {c['type']: c['status'] for c in space.status['conditions']} - available = conditions.get('Available', '') - progressing = conditions.get('Progressing', '') - degraded = conditions.get('Degraded', '') - - table_data.append([ - space.config.name, - namespace, - available, - progressing, - degraded - ]) - click.echo(tabulate(table_data, headers=["NAME", "NAMESPACE", "AVAILABLE", "PROGRESSING", "DEGRADED"])) - else: - click.echo("No spaces found") - except Exception as e: - click.echo(f"Error listing spaces: {e}", err=True) + click.echo("No spaces found") @click.command("hyp-space") @@ -78,18 +71,15 @@ def space_list(namespace, output): 
@click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") def space_describe(name, namespace, output): """Describe a space resource.""" - try: - current_space = HPSpace.get(name=name, namespace=namespace) - - # Combine config and raw resource data - current_space.raw_resource.get('metadata', {}).pop('managedFields', None) - - if output == "json": - click.echo(json.dumps(current_space.raw_resource, indent=2)) - else: - click.echo(yaml.dump(current_space.raw_resource, default_flow_style=False)) - except Exception as e: - click.echo(f"Error describing space '{name}': {e}", err=True) + current_space = HPSpace.get(name=name, namespace=namespace) + + # Combine config and raw resource data + current_space.raw_resource.get('metadata', {}).pop('managedFields', None) + + if output == "json": + click.echo(json.dumps(current_space.raw_resource, indent=2)) + else: + click.echo(yaml.dump(current_space.raw_resource, default_flow_style=False)) @click.command("hyp-space") @@ -97,13 +87,9 @@ def space_describe(name, namespace, output): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_delete(name, namespace): """Delete a space resource.""" - try: - current_space = HPSpace.get(name=name, namespace=namespace) - current_space.delete() - - click.echo(f"Space '{name}' deleted successfully in namespace '{namespace}'") - except Exception as e: - click.echo(f"Error deleting space '{name}': {e}", err=True) + current_space = HPSpace.get(name=name, namespace=namespace) + current_space.delete() + click.echo(f"Requested deletion for Space '{name}' in namespace '{namespace}'") @click.command("hyp-space") @@ -114,16 +100,12 @@ def space_delete(name, namespace): ) def space_update(version, config): """Update a space resource.""" - try: - current_space = HPSpace.get(name=config['name'], namespace=config['namespace']) - if not config.get("display_name"): - config["display_name"] = current_space.config.display_name - 
- current_space.update(**config) + current_space = HPSpace.get(name=config['name'], namespace=config['namespace']) + if not config.get("display_name"): + config["display_name"] = current_space.config.display_name - click.echo(f"Space '{current_space.config.name}' updated successfully in namespace '{config['namespace']}'") - except Exception as e: - click.echo(f"Error updating space: {e}", err=True) + current_space.update(**config) + click.echo(f"Space '{current_space.config.name}' updated successfully in namespace '{config['namespace']}'") @click.command("hyp-space") @@ -131,13 +113,9 @@ def space_update(version, config): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_start(name, namespace): """Start a space resource.""" - try: - current_space = HPSpace.get(name=name, namespace=namespace) - current_space.start() - - click.echo(f"Space '{name}' start requested") - except Exception as e: - click.echo(f"Error starting space '{name}': {e}", err=True) + current_space = HPSpace.get(name=name, namespace=namespace) + current_space.start() + click.echo(f"Space '{name}' start requested") @click.command("hyp-space") @@ -145,13 +123,9 @@ def space_start(name, namespace): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") def space_stop(name, namespace): """Stop a space resource.""" - try: - current_space = HPSpace.get(name=name, namespace=namespace) - current_space.stop() - - click.echo(f"Space '{name}' stop requested") - except Exception as e: - click.echo(f"Error stopping space '{name}': {e}", err=True) + current_space = HPSpace.get(name=name, namespace=namespace) + current_space.stop() + click.echo(f"Space '{name}' stop requested") @click.command("hyp-space") @@ -161,9 +135,6 @@ def space_stop(name, namespace): @click.option("--container", required=False, help="Name of the container to get logs from") def space_get_logs(name, namespace, pod_name, container): """Get 
logs for a space resource.""" - try: - current_space = HPSpace.get(name=name, namespace=namespace) - logs = current_space.get_logs(pod_name=pod_name, container=container) - click.echo(logs) - except Exception as e: - click.echo(f"Error getting logs for space '{name}': {e}", err=True) + current_space = HPSpace.get(name=name, namespace=namespace) + logs = current_space.get_logs(pod_name=pod_name, container=container) + click.echo(logs) diff --git a/src/sagemaker/hyperpod/cli/commands/space_access.py b/src/sagemaker/hyperpod/cli/commands/space_access.py index f35994bc..1de7e96c 100644 --- a/src/sagemaker/hyperpod/cli/commands/space_access.py +++ b/src/sagemaker/hyperpod/cli/commands/space_access.py @@ -16,10 +16,6 @@ ) def space_access_create(name, namespace, connection_type): """Create a space access resource.""" - - try: - space = HPSpace.get(name=name, namespace=namespace) - response = space.create_space_access(connection_type=connection_type) - click.echo(response) - except Exception as e: - click.echo(f"Error creating space access: {e}", err=True) + space = HPSpace.get(name=name, namespace=namespace) + response = space.create_space_access(connection_type=connection_type) + click.echo(response) diff --git a/src/sagemaker/hyperpod/cli/commands/space_template.py b/src/sagemaker/hyperpod/cli/commands/space_template.py index 540125ae..ab84ee5c 100644 --- a/src/sagemaker/hyperpod/cli/commands/space_template.py +++ b/src/sagemaker/hyperpod/cli/commands/space_template.py @@ -9,76 +9,66 @@ @click.option("--file", "-f", required=True, help="YAML file containing the configuration") def space_template_create(file): """Create a space-template resource.""" - try: - template = HPSpaceTemplate(file_path=file) - template.create() - click.echo(f"Space template '{template.name}' created successfully") - except Exception as e: - click.echo(f"Error creating space template: {e}", err=True) + template = HPSpaceTemplate(file_path=file) + template.create() + click.echo(f"Space template 
'{template.name}' in namespace '{template.namespace}' created successfully") @click.command("hyp-space-template") +@click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") -def space_template_list(output): +def space_template_list(namespace, output): """List space-template resources.""" - try: - templates = HPSpaceTemplate.list() - - if output == "json": - templates_data = [template.to_dict() for template in templates] - click.echo(json.dumps(templates_data, indent=2)) + templates = HPSpaceTemplate.list(namespace) + + if output == "json": + templates_data = [template.to_dict() for template in templates] + click.echo(json.dumps(templates_data, indent=2)) + else: + if templates: + table_data = [] + for template in templates: + table_data.append([ + template.namespace, + template.name, + template.config_data.get("spec", {}).get("displayName", ""), + template.config_data.get("spec", {}).get("defaultImage", ""), + ]) + click.echo(tabulate(table_data, headers=["NAMESPACE", "NAME", "DISPLAY_NAME", "DEFAULT_IMAGE"])) else: - if templates: - table_data = [] - for template in templates: - table_data.append([ - template.name, - template.config_data.get("spec", {}).get("displayName", ""), - template.config_data.get("spec", {}).get("defaultImage", ""), - ]) - click.echo(tabulate(table_data, headers=["NAME", "DISPLAY_NAME", "DEFAULT_IMAGE"])) - else: - click.echo("No space templates found") - except Exception as e: - click.echo(f"Error listing space templates: {e}", err=True) + click.echo("No space templates found") @click.command("hyp-space-template") @click.option("--name", required=True, help="Name of the space template") +@click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") -def space_template_describe(name, output): +def 
space_template_describe(name, namespace, output): """Describe a space-template resource.""" - try: - template = HPSpaceTemplate.get(name) - - if output == "json": - click.echo(json.dumps(template.to_dict(), indent=2)) - else: - click.echo(template.to_yaml()) - except Exception as e: - click.echo(f"Error describing space template '{name}': {e}", err=True) + template = HPSpaceTemplate.get(name, namespace) + + if output == "json": + click.echo(json.dumps(template.to_dict(), indent=2)) + else: + click.echo(template.to_yaml()) @click.command("hyp-space-template") @click.option("--name", required=True, help="Name of the space template") -def space_template_delete(name): +@click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") +def space_template_delete(name, namespace): """Delete a space-template resource.""" - try: - template = HPSpaceTemplate.get(name) - template.delete() - click.echo(f"Space template '{name}' deleted successfully") - except Exception as e: - click.echo(f"Error deleting space template '{name}': {e}", err=True) + template = HPSpaceTemplate.get(name, namespace) + template.delete() + click.echo(f"Requested deletion for Space template '{name}' in namespace '{namespace}'") @click.command("hyp-space-template") @click.option("--name", required=True, help="Name of the space template") +@click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--file", "-f", required=True, help="YAML file containing the updated template") -def space_template_update(name, file): +def space_template_update(name, namespace, file): """Update a space-template resource.""" - try: - template = HPSpaceTemplate.get(name) - template.update(file) - click.echo(f"Space template '{name}' updated successfully") - except Exception as e: - click.echo(f"Error updating space template '{name}': {e}", err=True) + template = HPSpaceTemplate.get(name, namespace) + template.update(file) + click.echo(f"Space 
template '{name}' in namespace '{namespace}' updated successfully") diff --git a/src/sagemaker/hyperpod/cli/space_utils.py b/src/sagemaker/hyperpod/cli/space_utils.py index c8193098..b84020c7 100644 --- a/src/sagemaker/hyperpod/cli/space_utils.py +++ b/src/sagemaker/hyperpod/cli/space_utils.py @@ -1,6 +1,7 @@ import json import pkgutil import click +import re from typing import Callable, Optional, Mapping, Type, Dict, Any from pydantic import ValidationError from sagemaker.hyperpod.cli.constants.space_constants import IMMUTABLE_FIELDS @@ -167,7 +168,35 @@ def _parse_template_ref(ctx, param, value): return parts except Exception as e: raise click.UsageError(f"Error parsing template ref: {str(e)}") - + + def _parse_idle_shutdown_param(ctx, param, value): + """Parse idle shutdown parameters from command line format to dictionary format.""" + if not value: + return None + + try: + parts = {} + for item in re.split(r',(?![^{]*})', value): + if '=' not in item: + raise click.UsageError(f"Invalid idle-shutdown format: '{item}' should be key=value") + key, val = item.split('=', 1) + key = key.strip() + val = val.strip() + + if key == 'idle_timeout_in_minutes': + key = 'idleTimeoutInMinutes' + elif key == 'enabled': + val = val.lower() in ('True', 'true', '1', 'yes') + elif key == 'detection': + try: + val = json.loads(val) + except json.JSONDecodeError: + raise click.UsageError(f"Invalid JSON for --{key}: {val}") + parts[key] = val + return parts + except Exception as e: + raise click.UsageError(f"Error parsing idle-shutdown: {str(e)}") + # 1) the wrapper click will call def wrapped_func(*args, **kwargs): version = version_key or kwargs.pop("version", "1.0") @@ -201,6 +230,14 @@ def wrapped_func(*args, **kwargs): if container_config is not None: kwargs["container_config"] = container_config + template_ref = kwargs.pop("template_ref", None) + if template_ref is not None: + kwargs["template_ref"] = template_ref + + idle_shutdown = kwargs.pop("idle_shutdown", None) + if 
idle_shutdown is not None: + kwargs["idle_shutdown"] = idle_shutdown + # filter out None/empty values so Pydantic model defaults apply filtered_kwargs = {} for key, value in kwargs.items(): @@ -255,56 +292,56 @@ def wrapped_func(*args, **kwargs): "--cpu", type=str, default=None, - help="CPU resource request, e.g. '250m'", + help="CPU resource request, e.g. '500m'", )(wrapped_func) wrapped_func = click.option( "--cpu-limit", type=str, default=None, - help="CPU resource limit, e.g. '250m'", + help="CPU resource limit, e.g. '500m'", )(wrapped_func) wrapped_func = click.option( "--memory", type=str, default=None, - help="Memory resource request, e.g. '256Mi'", + help="Memory resource request, e.g. '2Gi'", )(wrapped_func) wrapped_func = click.option( "--memory-limit", type=str, default=None, - help="Memory resource limit, e.g. '256Mi'", + help="Memory resource limit, e.g. '2Gi'", )(wrapped_func) wrapped_func = click.option( "--gpu", type=str, default=None, - help="Gpu resource request, e.g. '1'", + help="GPU resource request, e.g. '1'", )(wrapped_func) wrapped_func = click.option( "--gpu-limit", type=str, default=None, - help="Gpu resource limit, e.g. '1'", + help="GPU resource limit, e.g. '1'", )(wrapped_func) wrapped_func = click.option( "--accelerator-partition-type", type=str, default=None, - help="Fractional GPU parition type", + help="Fractional GPU partition type, e.g. 'mig-3g.20gb'", )(wrapped_func) wrapped_func = click.option( "--accelerator-partition-count", type=str, default=None, - help="Fractional GPU parition count", + help="Fractional GPU partition count, e.g. '1'", )(wrapped_func) wrapped_func = click.option( @@ -334,6 +371,12 @@ def wrapped_func(*args, **kwargs): help="TemplateRef references a WorkspaceTemplate to use as base configuration. Format: --template-ref name=,namespace=", )(wrapped_func) + wrapped_func = click.option( + "--idle-shutdown", + callback=_parse_idle_shutdown_param, + help="Idle shutdown configuration. 
Format: --idle-shutdown enabled=,idleTimeoutInMinutes=,detection=", + )(wrapped_func) + # Exclude the props that were handled out of the below for loop excluded_props = set( [ @@ -343,6 +386,7 @@ def wrapped_func(*args, **kwargs): "storage", "container_config", "template_ref", + "idle_shutdown", ] ) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index efc098db..da5555bd 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -197,10 +197,10 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: ) for item in response.get("items", []): - # Check if space was created by the caller - # TODO: need to also check OwnershipType when it's implemented in the operator + # Check if space was created by the caller or it's set as 'Public' created_by = item.get('metadata', {}).get('annotations', {}).get('workspace.jupyter.org/created-by') - if created_by == caller_arn: + ownership_type = item.get('spec', {}).get('ownershipType', '') + if created_by == caller_arn or ownership_type == "Public": config_data = map_kubernetes_response_to_model(item, SpaceConfig) space_config = SpaceConfig(**config_data) @@ -403,21 +403,17 @@ def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = No raise RuntimeError(f"No pods found for space '{self.config.name}'") pod_name = pods[0] + if not container: + container = "workspace" + v1 = client.CoreV1Api() try: - if container: - logs = v1.read_namespaced_pod_log( - name=pod_name, - namespace=self.config.namespace, - container=container - ) - else: - logs = v1.read_namespaced_pod_log( - name=pod_name, - namespace=self.config.namespace - ) - return logs + return v1.read_namespaced_pod_log( + name=pod_name, + namespace=self.config.namespace, + container=container + ) except Exception as e: handle_exception(e, pod_name, self.config.namespace) diff --git a/test/unit_tests/cli/test_space.py 
b/test/unit_tests/cli/test_space.py index 2341c0a4..f2073f82 100644 --- a/test/unit_tests/cli/test_space.py +++ b/test/unit_tests/cli/test_space.py @@ -78,42 +78,6 @@ def test_space_create_missing_required_args(self, mock_load_schema): assert result.exit_code != 0 assert 'Missing option' in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_create_hp_space_error(self, mock_hp_space_class): - """Test space creation error handling""" - mock_hp_space_instance = Mock() - mock_hp_space_instance.create.side_effect = Exception("Creation failed") - mock_hp_space_class.return_value = mock_hp_space_instance - - mock_model = Mock() - mock_model.return_value = Mock() - mock_model.return_value.to_domain.return_value = { - "name": "test-space", - "display_name": "Test Space", - "namespace": "test-ns", - "space_spec": {} - } - - with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - with patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') as mock_load_schema: - mock_load_schema.return_value = { - "properties": { - "name": {"type": "string"}, - "display_name": {"type": "string"}, - "namespace": {"type": "string"} - }, - "required": ["name", "display_name"] - } - result = self.runner.invoke(space_create, [ - '--version', '1.0', - '--name', 'test-space', - '--display-name', 'Test Space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error creating space: Creation failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_list_table_output(self, mock_hp_space_class): """Test space list with table output""" @@ -176,18 +140,6 @@ def test_space_list_empty(self, mock_hp_space_class): assert result.exit_code == 0 assert "No spaces found" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_list_error(self, mock_hp_space_class): - """Test space list error handling""" - 
mock_hp_space_class.list.side_effect = Exception("List failed") - - result = self.runner.invoke(space_list, [ - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error listing spaces: List failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_describe_yaml_output(self, mock_hp_space_class): """Test space describe with YAML output""" @@ -225,19 +177,6 @@ def test_space_describe_json_output(self, mock_hp_space_class): output_json = json.loads(result.output) assert output_json == mock_resource - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_describe_hp_space_error(self, mock_hp_space_class): - """Test space describe error handling""" - mock_hp_space_class.get.side_effect = Exception("Describe failed") - - result = self.runner.invoke(space_describe, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error describing space 'test-space': Describe failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_delete_success(self, mock_hp_space_class): """Test successful space deletion""" @@ -250,22 +189,10 @@ def test_space_delete_success(self, mock_hp_space_class): ]) assert result.exit_code == 0 - assert "Space 'test-space' deleted successfully" in result.output + assert "Requested deletion for Space 'test-space' in namespace 'test-ns'" in result.output mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') mock_hp_space_instance.delete.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_delete_hp_space_error(self, mock_hp_space_class): - """Test space delete error handling""" - mock_hp_space_class.get.side_effect = Exception("Delete failed") - - result = self.runner.invoke(space_delete, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error deleting space 
'test-space': Delete failed" in result.output @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') @@ -309,41 +236,6 @@ def test_space_update_success(self, mock_load_schema, mock_hp_space_class): mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') mock_hp_space_instance.update.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_update_hp_space_error(self, mock_hp_space_class): - """Test space update error handling""" - mock_hp_space_instance = Mock() - mock_hp_space_instance.update.side_effect = Exception("Update failed") - mock_hp_space_class.get.return_value = mock_hp_space_instance - - mock_model = Mock() - mock_model.return_value = Mock() - mock_model.return_value.to_domain.return_value = { - "name": "test-space", - "namespace": "test-ns", - "space_spec": {} - } - - with patch('hyperpod_space_template.registry.SCHEMA_REGISTRY', {'1.0': mock_model}): - with patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') as mock_load_schema: - mock_load_schema.return_value = { - "properties": { - "name": {"type": "string"}, - "display_name": {"type": "string"}, - "namespace": {"type": "string"} - }, - "required": ["name"] - } - result = self.runner.invoke(space_update, [ - '--version', '1.0', - '--name', 'test-space', - '--display-name', 'Test Space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error updating space: Update failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_start_success(self, mock_hp_space_class): """Test successful space start""" @@ -360,21 +252,6 @@ def test_space_start_success(self, mock_hp_space_class): mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') mock_hp_space_instance.start.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def 
test_space_start_hp_space_error(self, mock_hp_space_class): - """Test space start error handling""" - mock_hp_space_instance = Mock() - mock_hp_space_instance.start.side_effect = Exception("Start failed") - mock_hp_space_class.get.return_value = mock_hp_space_instance - - result = self.runner.invoke(space_start, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error starting space 'test-space': Start failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_stop_success(self, mock_hp_space_class): """Test successful space stop""" @@ -391,21 +268,6 @@ def test_space_stop_success(self, mock_hp_space_class): mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') mock_hp_space_instance.stop.assert_called_once() - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_stop_hp_space_error(self, mock_hp_space_class): - """Test space stop error handling""" - mock_hp_space_instance = Mock() - mock_hp_space_instance.stop.side_effect = Exception("Stop failed") - mock_hp_space_class.get.return_value = mock_hp_space_instance - - result = self.runner.invoke(space_stop, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error stopping space 'test-space': Stop failed" in result.output - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') def test_space_get_logs_success(self, mock_hp_space_class): """Test successful space get logs""" @@ -438,21 +300,6 @@ def test_space_get_logs_no_pods(self, mock_hp_space_class): assert result.exit_code == 0 # HPSpace.get_logs() handles the "no pods" case internally - @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_get_logs_hp_space_error(self, mock_hp_space_class): - """Test space get logs error handling""" - mock_hp_space_instance = Mock() - mock_hp_space_instance.get_logs.side_effect = Exception("Get logs failed") - 
mock_hp_space_class.get.return_value = mock_hp_space_instance - - result = self.runner.invoke(space_get_logs, [ - '--name', 'test-space', - '--namespace', 'test-ns' - ]) - - assert result.exit_code == 0 - assert "Error getting logs for space 'test-space': Get logs failed" in result.output - def test_missing_required_arguments(self): """Test commands with missing required arguments""" # Test create without name diff --git a/test/unit_tests/cli/test_space_access.py b/test/unit_tests/cli/test_space_access.py index 9602edc7..717047e7 100644 --- a/test/unit_tests/cli/test_space_access.py +++ b/test/unit_tests/cli/test_space_access.py @@ -51,17 +51,3 @@ def test_space_access_create_default_values(self, mock_hp_space_class): assert "https://default-url.com" in result.output mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='default') mock_space_instance.create_space_access.assert_called_once_with(connection_type='vscode-remote') - - @patch('sagemaker.hyperpod.cli.commands.space_access.HPSpace') - def test_space_access_create_api_error(self, mock_hp_space_class): - """Test space access creation when API call fails""" - mock_space_instance = Mock() - mock_space_instance.create_space_access.side_effect = Exception("API error") - mock_hp_space_class.get.return_value = mock_space_instance - - result = self.runner.invoke(space_access_create, [ - '--name', 'test-space' - ]) - - assert result.exit_code == 0 - assert "Error creating space access: API error" in result.output diff --git a/test/unit_tests/cli/test_space_template.py b/test/unit_tests/cli/test_space_template.py index 5f8dfd41..fa9f25ae 100644 --- a/test/unit_tests/cli/test_space_template.py +++ b/test/unit_tests/cli/test_space_template.py @@ -41,32 +41,27 @@ def test_space_template_create_success(self, mock_hp_space_template): """Test successful space template creation""" mock_template_instance = Mock() mock_template_instance.name = "test-template" + mock_template_instance.namespace = 
"default" mock_hp_space_template.return_value = mock_template_instance result = self.runner.invoke(space_template_create, ["--file", "test.yaml"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Space template 'test-template' created successfully", result.output) + self.assertIn("Space template 'test-template' in namespace 'default' created successfully", result.output) mock_hp_space_template.assert_called_once_with(file_path="test.yaml") mock_template_instance.create.assert_called_once() - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_create_file_not_found(self, mock_hp_space_template): - """Test space template creation with missing file""" - mock_hp_space_template.side_effect = FileNotFoundError("File 'nonexistent.yaml' not found") - - result = self.runner.invoke(space_template_create, ["--file", "nonexistent.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error creating space template: File 'nonexistent.yaml' not found", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_list_table_output(self, mock_hp_space_template): """Test space template list with table output""" mock_template1 = Mock() mock_template1.name = "template1" + mock_template1.namespace = "default" + mock_template1.config_data = {"spec": {"displayName": "Template 1", "defaultImage": "image1"}} mock_template2 = Mock() mock_template2.name = "template2" + mock_template2.namespace = "test" + mock_template2.config_data = {"spec": {"displayName": "Template 2", "defaultImage": "image2"}} mock_hp_space_template.list.return_value = [mock_template1, mock_template2] result = self.runner.invoke(space_template_list, ["--output", "table"]) @@ -74,7 +69,9 @@ def test_space_template_list_table_output(self, mock_hp_space_template): self.assertEqual(result.exit_code, 0) self.assertIn("template1", result.output) self.assertIn("template2", result.output) + 
self.assertIn("NAMESPACE", result.output) self.assertIn("NAME", result.output) + mock_hp_space_template.list.assert_called_once_with(None) @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_list_json_output(self, mock_hp_space_template): @@ -92,6 +89,7 @@ def test_space_template_list_json_output(self, mock_hp_space_template): self.assertEqual(len(output_json), 2) self.assertEqual(output_json[0]["metadata"]["name"], "template1") self.assertEqual(output_json[1]["metadata"]["name"], "template2") + mock_hp_space_template.list.assert_called_once_with(None) @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_list_empty(self, mock_hp_space_template): @@ -102,16 +100,23 @@ def test_space_template_list_empty(self, mock_hp_space_template): self.assertEqual(result.exit_code, 0) self.assertIn("No space templates found", result.output) + mock_hp_space_template.list.assert_called_once_with(None) @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_list_error(self, mock_hp_space_template): - """Test space template list with error""" - mock_hp_space_template.list.side_effect = Exception("List error") + def test_space_template_list_with_namespace(self, mock_hp_space_template): + """Test space template list with namespace parameter""" + mock_template1 = Mock() + mock_template1.name = "template1" + mock_template1.namespace = "test-namespace" + mock_template1.config_data = {"spec": {"displayName": "Template 1", "defaultImage": "image1"}} + mock_hp_space_template.list.return_value = [mock_template1] - result = self.runner.invoke(space_template_list) + result = self.runner.invoke(space_template_list, ["--namespace", "test-namespace", "--output", "table"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Error listing space templates: List error", result.output) + self.assertIn("template1", result.output) + self.assertIn("test-namespace", 
result.output) + mock_hp_space_template.list.assert_called_once_with("test-namespace") @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_describe_yaml_output(self, mock_hp_space_template): @@ -125,7 +130,7 @@ def test_space_template_describe_yaml_output(self, mock_hp_space_template): self.assertEqual(result.exit_code, 0) self.assertIn("name: test-template", result.output) self.assertIn("displayName: Test Template", result.output) - mock_hp_space_template.get.assert_called_once_with("test-template") + mock_hp_space_template.get.assert_called_once_with("test-template", None) @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_describe_json_output(self, mock_hp_space_template): @@ -143,16 +148,7 @@ def test_space_template_describe_json_output(self, mock_hp_space_template): output_json = json.loads(result.output) self.assertEqual(output_json["metadata"]["name"], "test-template") self.assertEqual(output_json["spec"]["displayName"], "Test Template") - - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_describe_error(self, mock_hp_space_template): - """Test space template describe with error""" - mock_hp_space_template.get.side_effect = Exception("Not found") - - result = self.runner.invoke(space_template_describe, ["--name", "nonexistent"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error describing space template 'nonexistent': Not found", result.output) + mock_hp_space_template.get.assert_called_once_with("test-template", None) @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_delete_success(self, mock_hp_space_template): @@ -163,20 +159,10 @@ def test_space_template_delete_success(self, mock_hp_space_template): result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Space template 
'test-template' deleted successfully", result.output) - mock_hp_space_template.get.assert_called_once_with("test-template") + self.assertIn("Requested deletion for Space template 'test-template' in namespace 'None'", result.output) + mock_hp_space_template.get.assert_called_once_with("test-template", None) mock_template_instance.delete.assert_called_once() - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_delete_error(self, mock_hp_space_template): - """Test space template deletion with error""" - mock_hp_space_template.get.side_effect = Exception("Delete error") - - result = self.runner.invoke(space_template_delete, ["--name", "test-template"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error deleting space template 'test-template': Delete error", result.output) - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") def test_space_template_update_success(self, mock_hp_space_template): """Test successful space template update""" @@ -186,18 +172,6 @@ def test_space_template_update_success(self, mock_hp_space_template): result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) self.assertEqual(result.exit_code, 0) - self.assertIn("Space template 'test-template' updated successfully", result.output) - mock_hp_space_template.get.assert_called_once_with("test-template") + self.assertIn("Space template 'test-template' in namespace 'None' updated successfully", result.output) + mock_hp_space_template.get.assert_called_once_with("test-template", None) mock_template_instance.update.assert_called_once_with("test.yaml") - - @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_update_yaml_error(self, mock_hp_space_template): - """Test space template update with YAML parsing error""" - mock_template_instance = Mock() - mock_template_instance.update.side_effect = yaml.YAMLError("Invalid YAML") - 
mock_hp_space_template.get.return_value = mock_template_instance - - result = self.runner.invoke(space_template_update, ["--name", "test-template", "--file", "test.yaml"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error updating space template 'test-template': Invalid YAML", result.output) diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index c06e89c2..b0de3933 100644 --- a/test/unit_tests/test_hyperpod_space.py +++ b/test/unit_tests/test_hyperpod_space.py @@ -52,7 +52,7 @@ def test_verify_kube_config_already_loaded(self): @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') def test_create_success(self, mock_verify_config, mock_custom_api_class): - """Test successful dev space creation""" + """Test successful space creation""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -76,7 +76,7 @@ def test_create_success(self, mock_verify_config, mock_custom_api_class): @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') def test_create_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): - """Test dev space creation failure""" + """Test space creation failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -102,7 +102,7 @@ def test_create_failure(self, mock_handle_exception, mock_verify_config, mock_cu @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.get_default_namespace') def test_list_success(self, mock_get_namespace, mock_verify_config, mock_custom_api_class, mock_boto3_client): - """Test successful dev space listing""" + """Test successful space listing""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_get_namespace.return_value = "default" @@ -149,7 +149,7 @@ def test_list_success(self, mock_get_namespace, 
mock_verify_config, mock_custom_ @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') def test_list_with_namespace(self, mock_verify_config, mock_custom_api_class, mock_boto3_client): - """Test dev space listing with specific namespace""" + """Test space listing with specific namespace""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -352,7 +352,7 @@ def test_list_no_matching_spaces_across_pages(self, mock_verify_config, mock_cus @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') def test_list_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class, mock_boto3_client): - """Test dev space listing failure""" + """Test space listing failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_custom_api.list_namespaced_custom_object.side_effect = Exception("List failed") @@ -364,7 +364,7 @@ def test_list_failure(self, mock_handle_exception, mock_verify_config, mock_cust @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') def test_get_success(self, mock_verify_config, mock_custom_api_class): - """Test successful dev space retrieval""" + """Test successful space retrieval""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -383,7 +383,7 @@ def test_get_success(self, mock_verify_config, mock_custom_api_class): @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') def test_get_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): - """Test dev space retrieval failure""" + """Test space retrieval failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_custom_api.get_namespaced_custom_object.side_effect = Exception("Get failed") @@ -402,7 
+402,7 @@ def test_get_failure(self, mock_handle_exception, mock_verify_config, mock_custo @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') def test_delete_success(self, mock_verify_config, mock_custom_api_class): - """Test successful dev space deletion""" + """Test successful space deletion""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -421,7 +421,7 @@ def test_delete_success(self, mock_verify_config, mock_custom_api_class): @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') def test_delete_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): - """Test dev space deletion failure""" + """Test space deletion failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_custom_api.delete_namespaced_custom_object.side_effect = Exception("Delete failed") @@ -433,7 +433,7 @@ def test_delete_failure(self, mock_handle_exception, mock_verify_config, mock_cu @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') @patch.object(HPSpace, 'verify_kube_config') def test_update_success(self, mock_verify_config, mock_custom_api_class): - """Test successful dev space update""" + """Test successful space update""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api @@ -459,7 +459,7 @@ def test_update_success(self, mock_verify_config, mock_custom_api_class): @patch.object(HPSpace, 'verify_kube_config') @patch('sagemaker.hyperpod.space.hyperpod_space.handle_exception') def test_update_failure(self, mock_handle_exception, mock_verify_config, mock_custom_api_class): - """Test dev space update failure""" + """Test space update failure""" mock_custom_api = Mock() mock_custom_api_class.return_value = mock_custom_api mock_custom_api.patch_namespaced_custom_object.side_effect = Exception("Update failed") @@ -473,13 +473,13 
@@ def test_update_failure(self, mock_handle_exception, mock_verify_config, mock_cu @patch.object(HPSpace, 'update') def test_start(self, mock_update): - """Test dev space start""" + """Test space start""" self.hp_space.start() mock_update.assert_called_once_with(desired_status="Running") @patch.object(HPSpace, 'update') def test_stop(self, mock_update): - """Test dev space stop""" + """Test space stop""" self.hp_space.stop() mock_update.assert_called_once_with(desired_status="Stopped") @@ -534,7 +534,8 @@ def test_get_logs_with_pod_name(self, mock_list_pods, mock_verify_config, mock_c self.assertEqual(result, "test logs") mock_core_api.read_namespaced_pod_log.assert_called_once_with( name="test-pod", - namespace="test-namespace" + namespace="test-namespace", + container="workspace", ) mock_list_pods.assert_not_called() @@ -553,7 +554,8 @@ def test_get_logs_without_pod_name(self, mock_list_pods, mock_verify_config, moc self.assertEqual(result, "test logs") mock_core_api.read_namespaced_pod_log.assert_called_once_with( name="pod1", - namespace="test-namespace" + namespace="test-namespace", + container="workspace", ) @patch.object(HPSpace, 'verify_kube_config') From 75affc21e04995b70d2dfc1b3befb2488123d7cb Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Wed, 19 Nov 2025 20:22:35 -0800 Subject: [PATCH 22/31] Fix the template ref regression (#300) --- hyperpod-space-template/hyperpod_space_template/v1_0/model.py | 4 ++-- .../hyperpod_space_template/v1_0/schema.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py index 58084e11..5bf4d56e 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/model.py +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/model.py @@ -44,7 +44,7 @@ class ContainerConfig(BaseModel): class TemplateRef(BaseModel): - """ContainerConfig defines container command and args 
configuration""" + """TemplateRef defines a reference to a WorkspaceTemplate""" name: str = Field( description="Name of the WorkspaceTemplate" ) @@ -248,7 +248,7 @@ def to_domain(self) -> Dict: if self.lifecycle is not None: spec["lifecycle"] = self.lifecycle if self.template_ref is not None: - spec["templateRef"] = self.template_ref + spec["templateRef"] = self.template_ref.model_dump(exclude_none=True, by_alias=True) if self.idle_shutdown is not None: spec["idleShutdown"] = self.idle_shutdown.model_dump(exclude_none=True, by_alias=True) if self.app_type is not None: diff --git a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json index b94817f5..eb9659d7 100644 --- a/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json +++ b/hyperpod-space-template/hyperpod_space_template/v1_0/schema.json @@ -203,7 +203,7 @@ "type": "object" }, "TemplateRef": { - "description": "ContainerConfig defines container command and args configuration", + "description": "TemplateRef defines a reference to a WorkspaceTemplate", "properties": { "name": { "description": "Name of the WorkspaceTemplate", From 84ff4b365a76d39102f2532111b1e72225360d23 Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Thu, 20 Nov 2025 16:32:35 -0800 Subject: [PATCH 23/31] Update SageMaker Space documentation (#301) --- README.md | 254 ++++++++++ doc/cli/cli_index.rst | 12 +- doc/cli/cli_reference.md | 11 +- doc/cli/space/cli_space.md | 410 ++++++++++++++++ doc/sdk/sdk_index.rst | 10 +- doc/sdk/space/hyperpod_space.rst | 30 ++ .../hyperpod/space/hyperpod_space.py | 454 ++++++++++++++++-- .../hyperpod/space/hyperpod_space_template.py | 353 ++++++++++++-- 8 files changed, 1440 insertions(+), 94 deletions(-) create mode 100644 doc/cli/space/cli_space.md create mode 100644 doc/sdk/space/hyperpod_space.rst diff --git a/README.md b/README.md index 72e1bc6c..9dbf7aa0 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,12 @@ Note: 
Old `hyperpod`CLI V2 has been moved to `release_v2` branch. Please refer [ - [Inference](#inference) - [Jumpstart Endpoint](#jumpstart-endpoint-creation) - [Custom Endpoint](#custom-endpoint-creation) + - [Space](#space) - [SDK](#sdk) - [Cluster Management](#cluster-management-sdk) - [Training](#training-sdk) - [Inference](#inference-sdk) + - [Space](#space-sdk) - [Examples](#examples) @@ -614,6 +616,105 @@ hyp get-operator-logs hyp-custom-endpoint --since-hours 0.5 hyp delete hyp-custom-endpoint --name endpoint-custom ``` +### Space + +#### Create a Space + +```bash +hyp create hyp-space \ + --name myspace \ + --namespace default \ + --display-name "My Space" +``` + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Space name | +| `--display-name` | TEXT | Yes | Display Name of the space | +| `--namespace` | TEXT | No | Kubernetes namespace | +| `--image` | TEXT | No | Image specifies the container image to use | +| `--desired-status` | TEXT | No | DesiredStatus specifies the desired operational status | +| `--ownership-type` | TEXT | No | OwnershipType specifies who can modify the space. 'Public' means anyone with RBAC permissions can update/delete the space. 'OwnerOnly' means only the creator can update/delete the space. 
| +| `--node-selector` | TEXT | No | NodeSelector specifies node selection constraints for the space pod (JSON string) | +| `--affinity` | TEXT | No | Affinity specifies node affinity and anti-affinity rules for the space pod (JSON string) | +| `--tolerations` | TEXT | No | Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON string) | +| `--lifecycle` | TEXT | No | Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON string) | +| `--app-type` | TEXT | No | AppType specifies the application type for this workspace | +| `--service-account-name` | TEXT | No | ServiceAccountName specifies the name of the ServiceAccount to use for the workspace pod | +| `--idle-shutdown` | TEXT | No | Idle shutdown configuration. Format: `--idle-shutdown enabled=<bool>,idleTimeoutInMinutes=<minutes>,detection=<detection>` | +| `--template-ref` | TEXT | No | TemplateRef references a WorkspaceTemplate to use as base configuration. Format: `--template-ref name=<name>,namespace=<namespace>` | +| `--container-config` | TEXT | No | Container configuration. Format: `--container-config command=<command>,args=<args>` | +| `--storage` | TEXT | No | Storage configuration. Format: `--storage storageClassName=<storage-class>,size=<size>,mountPath=<path>` | +| `--volume` | TEXT | No | Volume configuration. Format: `--volume name=<name>,mountPath=<path>,persistentVolumeClaimName=<pvc-name>`. Use multiple --volume flags for multiple volumes. | +| `--accelerator-partition-count` | TEXT | No | Fractional GPU partition count, e.g. '1' | +| `--accelerator-partition-type` | TEXT | No | Fractional GPU partition type, e.g. 'mig-3g.20gb' | +| `--gpu-limit` | TEXT | No | GPU resource limit, e.g. '1' | +| `--gpu` | TEXT | No | GPU resource request, e.g. '1' | +| `--memory-limit` | TEXT | No | Memory resource limit, e.g. '2Gi' | +| `--memory` | TEXT | No | Memory resource request, e.g. '2Gi' | +| `--cpu-limit` | TEXT | No | CPU resource limit, e.g. '500m' | +| `--cpu` | TEXT | No | CPU resource request, e.g. 
'500m' | + +#### List Spaces + +```bash +hyp list hyp-space +``` + +#### Describe a Space + +```bash +hyp describe hyp-space --name myspace +``` + +#### Update a Space + +```bash +hyp update hyp-space \ + --name myspace \ + --display-name "Updated Space Name" +``` + +#### Start/Stop a Space + +```bash +hyp start hyp-space --name myspace +hyp stop hyp-space --name myspace +``` + +#### Get Logs + +```bash +hyp get-logs hyp-space --name myspace +``` + +#### Delete a Space + +```bash +hyp delete hyp-space --name myspace +``` + +#### Space Template Management + +Create reusable space templates: + +```bash +hyp create hyp-space-template --file template.yaml +hyp list hyp-space-template +hyp describe hyp-space-template --name <template-name> +hyp update hyp-space-template --name <template-name> --file updated-template.yaml +hyp delete hyp-space-template --name <template-name> +``` + +#### Space Access + +Create remote access to spaces: + +```bash +hyp create hyp-space-access --name myspace --connection-type vscode-remote +hyp create hyp-space-access --name myspace --connection-type web-ui +``` + ## SDK Along with the CLI, we also have SDKs available that can perform the cluster management, training and inference functionalities that the CLI performs @@ -993,6 +1094,159 @@ from sagemaker.hyperpod.observability.utils import get_monitoring_config monitor_config = get_monitoring_config() ``` +### Space SDK + +#### Creating a Space + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace +from hyperpod_space_template.v1_0.model import SpaceConfig + +# Create space configuration +space_config = SpaceConfig( + name="myspace", + namespace="default", + display_name="My Space", +) + +# Create and start the space +space = HPSpace(config=space_config) +space.create() +``` + +#### List Spaces + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# List all spaces in default namespace +spaces = HPSpace.list() +for space in spaces: + print(f"Space: {space.config.name}, Status: {space.status}") 
+ +# List spaces in specific namespace +spaces = HPSpace.list(namespace="your-namespace") +``` + +#### Get a Space + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get specific space +space = HPSpace.get(name="myspace", namespace="default") +print(f"Space name: {space.config.name}") +print(f"Display name: {space.config.display_name}") +``` + +#### Update a Space + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get existing space +space = HPSpace.get(name="myspace") + +# Update space configuration +space.update( + display_name="Updated Space Name", +) +``` + +#### Start/Stop a Space + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get existing space +space = HPSpace.get(name="myspace") + +# Start the space +space.start() + +# Stop the space +space.stop() +``` + +#### Get Space Logs + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get space and retrieve logs +space = HPSpace.get(name="myspace") + +# Get logs from default pod and container +logs = space.get_logs() +print(logs) +``` + +#### List Space Pods + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get space and list associated pods +space = HPSpace.get(name="myspace") +pods = space.list_pods() +for pod in pods: + print(f"Pod: {pod}") +``` + +#### Create Space Access + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get existing space +space = HPSpace.get(name="myspace") + +# Create VS Code remote access +vscode_access = space.create_space_access(connection_type="vscode-remote") +print(f"VS Code URL: {vscode_access['SpaceConnectionUrl']}") + +# Create web UI access +web_access = space.create_space_access(connection_type="web-ui") +print(f"Web UI URL: {web_access['SpaceConnectionUrl']}") +``` + +#### Delete a Space + +```python +from sagemaker.hyperpod.space.hyperpod_space import HPSpace + +# Get existing space +space = 
HPSpace.get(name="myspace") + +# Delete the space +space.delete() +``` + +#### Space Template Management + +```python +from sagemaker.hyperpod.space.hyperpod_space_template import HPSpaceTemplate + +# Create space template from YAML file +template = HPSpaceTemplate(file_path="template.yaml") +template.create() + +# List all space templates +templates = HPSpaceTemplate.list() +for template in templates: + print(f"Template: {template.name}") + +# Get specific space template +template = HPSpaceTemplate.get(name="my-template") +print(template.to_yaml()) + +# Update space template +template.update(file_path="updated-template.yaml") + +# Delete space template +template.delete() +``` + ## Examples #### Cluster Management Example Notebooks diff --git a/doc/cli/cli_index.rst b/doc/cli/cli_index.rst index 3d3885a3..801c7c2f 100644 --- a/doc/cli/cli_index.rst +++ b/doc/cli/cli_index.rst @@ -10,10 +10,11 @@ Complete reference for the SageMaker HyperPod Command Line Interface. cluster_management/cli_cluster_management training/cli_training inference/cli_inference + space/cli_space .. container:: - .. grid:: 1 1 3 3 + .. grid:: 1 1 4 4 :gutter: 3 .. grid-item-card:: Cluster Management CLI @@ -35,4 +36,11 @@ Complete reference for the SageMaker HyperPod Command Line Interface. :link-type: doc :class-card: sd-border-secondary - Inference CLI commands, options and parameters. \ No newline at end of file + Inference CLI commands, options and parameters. + + .. grid-item-card:: Space CLI + :link: space/cli_space + :link-type: doc + :class-card: sd-border-secondary + + Space management commands, options and parameters. \ No newline at end of file diff --git a/doc/cli/cli_reference.md b/doc/cli/cli_reference.md index 6ae3af58..2e40599b 100644 --- a/doc/cli/cli_reference.md +++ b/doc/cli/cli_reference.md @@ -9,12 +9,13 @@ cli_training cli_inference cli_cluster_management +cli_space ``` Complete reference for the SageMaker HyperPod Command Line Interface. 
::::{container} -::::{grid} 1 1 3 3 +::::{grid} 1 1 4 4 :gutter: 3 :::{grid-item-card} Training CLI @@ -41,5 +42,13 @@ Inference CLI commands, options and parameters. Cluster stack management commands, options and parameters. ::: +:::{grid-item-card} Space CLI +:link: cli_space +:link-type: ref +:class-card: sd-border-secondary + +Space management commands, options and parameters. +::: + :::: :::: \ No newline at end of file diff --git a/doc/cli/space/cli_space.md b/doc/cli/space/cli_space.md new file mode 100644 index 00000000..c5b3b76d --- /dev/null +++ b/doc/cli/space/cli_space.md @@ -0,0 +1,410 @@ +(cli_space)= + +# Space + +Complete reference for Amazon SageMaker Space management commands and configuration options. + +```{note} +**Region Configuration**: For commands that accept the `--region` option, if no region is explicitly provided, the command will use the default region from your AWS credentials configuration. +``` + +* [Create Space](#hyp-create-hyp-space) +* [List Spaces](#hyp-list-hyp-space) +* [Describe Space](#hyp-describe-hyp-space) +* [Update Space](#hyp-update-hyp-space) +* [Delete Space](#hyp-delete-hyp-space) +* [Start Space](#hyp-start-hyp-space) +* [Stop Space](#hyp-stop-hyp-space) +* [Get Logs](#hyp-get-logs-hyp-space) +* [Create Space Access](#hyp-create-hyp-space-access) +* [Create Space Template](#hyp-create-hyp-space-template) +* [List Space Templates](#hyp-list-hyp-space-template) +* [Describe Space Template](#hyp-describe-hyp-space-template) +* [Update Space Template](#hyp-update-hyp-space-template) +* [Delete Space Template](#hyp-delete-hyp-space-template) + +## hyp create hyp-space + +Create a space resource on SageMaker HyperPod clusters. 
+ +### Syntax + +```bash +hyp create hyp-space [OPTIONS] +``` + +### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--version` | TEXT | No | Schema version to use | +| `--name` | TEXT | Yes | Space name | +| `--display-name` | TEXT | Yes | Display Name of the space | +| `--namespace` | TEXT | No | Kubernetes namespace | +| `--image` | TEXT | No | Image specifies the container image to use | +| `--desired-status` | TEXT | No | DesiredStatus specifies the desired operational status | +| `--ownership-type` | TEXT | No | OwnershipType specifies who can modify the space ('Public' or 'OwnerOnly') | +| `--node-selector` | TEXT | No | NodeSelector specifies node selection constraints for the space pod (JSON string) | +| `--affinity` | TEXT | No | Affinity specifies node affinity and anti-affinity rules for the space pod (JSON string) | +| `--tolerations` | TEXT | No | Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON string) | +| `--lifecycle` | TEXT | No | Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON string) | +| `--app-type` | TEXT | No | AppType specifies the application type for this workspace | +| `--service-account-name` | TEXT | No | ServiceAccountName specifies the name of the ServiceAccount to use for the workspace pod | +| `--idle-shutdown` | TEXT | No | Idle shutdown configuration. Format: enabled=,idleTimeoutInMinutes=,detection= | +| `--template-ref` | TEXT | No | TemplateRef references a WorkspaceTemplate to use as base configuration. Format: name=,namespace= | +| `--container-config` | TEXT | No | Container configuration. Format: command=,args= | +| `--storage` | TEXT | No | Storage configuration. Format: storageClassName=,size=,mountPath= | +| `--volume` | TEXT | No | Volume configuration. Format: name=,mountPath=,persistentVolumeClaimName=. 
Use multiple --volume flags for multiple volumes | +| `--accelerator-partition-count` | TEXT | No | Fractional GPU partition count, e.g. '1' | +| `--accelerator-partition-type` | TEXT | No | Fractional GPU partition type, e.g. 'mig-3g.20gb' | +| `--gpu-limit` | TEXT | No | GPU resource limit, e.g. '1' | +| `--gpu` | TEXT | No | GPU resource request, e.g. '1' | +| `--memory-limit` | TEXT | No | Memory resource limit, e.g. '2Gi' | +| `--memory` | TEXT | No | Memory resource request, e.g. '2Gi' | +| `--cpu-limit` | TEXT | No | CPU resource limit, e.g. '500m' | +| `--cpu` | TEXT | No | CPU resource request, e.g. '500m' | + +### Example + +```bash +hyp create hyp-space --version 1.0 --name my-space --display-name "My Space" --namespace default +``` + +## Space Management Commands + +Commands for managing Amazon SageMaker Spaces. + +### hyp list hyp-space + +List all spaces in a namespace. + +#### Syntax + +```bash +hyp list hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | +| `--output, -o` | TEXT | No | Output format: table or json (default: "table") | + +#### Example + +```bash +hyp list hyp-space --namespace default --output table +``` + +### hyp describe hyp-space + +Describe a specific space resource. + +#### Syntax + +```bash +hyp describe hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to describe | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | +| `--output, -o` | TEXT | No | Output format: yaml or json (default: "yaml") | + +#### Example + +```bash +hyp describe hyp-space --name my-space --namespace default --output yaml +``` + +### hyp update hyp-space + +Update an existing space resource.
+ +#### Syntax + +```bash +hyp update hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--version` | TEXT | No | Schema version to use | +| `--name` | TEXT | Yes | Space name | +| `--display-name` | TEXT | No | Display Name of the space | +| `--namespace` | TEXT | No | Kubernetes namespace | +| `--image` | TEXT | No | Image specifies the container image to use | +| `--desired-status` | TEXT | No | DesiredStatus specifies the desired operational status | +| `--ownership-type` | TEXT | No | OwnershipType specifies who can modify the space ('Public' or 'OwnerOnly') | +| `--node-selector` | TEXT | No | NodeSelector specifies node selection constraints for the space pod (JSON string) | +| `--affinity` | TEXT | No | Affinity specifies node affinity and anti-affinity rules for the space pod (JSON string) | +| `--tolerations` | TEXT | No | Tolerations specifies tolerations for the space pod to schedule on nodes with matching taints (JSON string) | +| `--lifecycle` | TEXT | No | Lifecycle specifies actions that the management system should take in response to container lifecycle events (JSON string) | +| `--app-type` | TEXT | No | AppType specifies the application type for this workspace | +| `--service-account-name` | TEXT | No | ServiceAccountName specifies the name of the ServiceAccount to use for the workspace pod | +| `--idle-shutdown` | TEXT | No | Idle shutdown configuration. Format: enabled=,idleTimeoutInMinutes=,detection= | +| `--template-ref` | TEXT | No | TemplateRef references a WorkspaceTemplate to use as base configuration. Format: name=,namespace= | +| `--container-config` | TEXT | No | Container configuration. Format: command=,args= | +| `--volume` | TEXT | No | Volume configuration. Format: name=,mountPath=,persistentVolumeClaimName=. 
Use multiple --volume flags for multiple volumes | +| `--accelerator-partition-count` | TEXT | No | Fractional GPU partition count, e.g. '1' | +| `--accelerator-partition-type` | TEXT | No | Fractional GPU partition type, e.g. 'mig-3g.20gb' | +| `--gpu-limit` | TEXT | No | GPU resource limit, e.g. '1' | +| `--gpu` | TEXT | No | GPU resource request, e.g. '1' | +| `--memory-limit` | TEXT | No | Memory resource limit, e.g. '2Gi' | +| `--memory` | TEXT | No | Memory resource request, e.g. '2Gi' | +| `--cpu-limit` | TEXT | No | CPU resource limit, e.g. '500m' | +| `--cpu` | TEXT | No | CPU resource request, e.g. '500m' | + +#### Example + +```bash +hyp update hyp-space --version 1.0 --name my-space --namespace default +``` + +### hyp delete hyp-space + +Delete a space resource. + +#### Syntax + +```bash +hyp delete hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to delete | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | + +#### Example + +```bash +hyp delete hyp-space --name my-space --namespace default +``` + +### hyp start hyp-space + +Start a space resource. + +#### Syntax + +```bash +hyp start hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to start | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | + +#### Example + +```bash +hyp start hyp-space --name my-space --namespace default +``` + +### hyp stop hyp-space + +Stop a space resource. 
+ +#### Syntax + +```bash +hyp stop hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to stop | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | + +#### Example + +```bash +hyp stop hyp-space --name my-space --namespace default +``` + +### hyp get-logs hyp-space + +Get logs from a space resource. + +#### Syntax + +```bash +hyp get-logs hyp-space [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to get logs from | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | +| `--pod-name` | TEXT | No | Name of the specific pod to get logs from | +| `--container` | TEXT | No | Name of the specific container to get logs from | + +#### Example + +```bash +hyp get-logs hyp-space --name my-space --namespace default --pod-name my-pod +``` + +## Space Access Commands + +Commands for managing space access resources. + +### hyp create hyp-space-access + +Create a space access resource for remote connection to a space. + +#### Syntax + +```bash +hyp create hyp-space-access [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space to create access for | +| `--namespace, -n` | TEXT | No | Kubernetes namespace (default: "default") | +| `--connection-type, -t` | TEXT | No | Remote access type: vscode-remote or web-ui (default: "vscode-remote") | + +#### Example + +```bash +hyp create hyp-space-access --name my-space --namespace default --connection-type vscode-remote +``` + +## Space Template Commands + +Commands for managing space template resources. + +### hyp create hyp-space-template + +Create a space template resource from a YAML configuration file. 
+ +#### Syntax + +```bash +hyp create hyp-space-template [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--file, -f` | TEXT | Yes | YAML file containing the template configuration | + +#### Example + +```bash +hyp create hyp-space-template --file my-template.yaml +``` + +### hyp list hyp-space-template + +List all space template resources. + +#### Syntax + +```bash +hyp list hyp-space-template [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--namespace, -n` | TEXT | No | Kubernetes namespace | +| `--output, -o` | TEXT | No | Output format: table or json (default: "table") | + +#### Example + +```bash +hyp list hyp-space-template --namespace default --output table +``` + +### hyp describe hyp-space-template + +Describe a specific space template resource. + +#### Syntax + +```bash +hyp describe hyp-space-template [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space template to describe | +| `--namespace, -n` | TEXT | No | Kubernetes namespace | +| `--output, -o` | TEXT | No | Output format: yaml or json (default: "yaml") | + +#### Example + +```bash +hyp describe hyp-space-template --name my-template --namespace default --output yaml +``` + +### hyp update hyp-space-template + +Update an existing space template resource. 
+ +#### Syntax + +```bash +hyp update hyp-space-template [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space template to update | +| `--namespace, -n` | TEXT | No | Kubernetes namespace | +| `--file, -f` | TEXT | Yes | YAML file containing the updated template configuration | + +#### Example + +```bash +hyp update hyp-space-template --name my-template --namespace default --file updated-template.yaml +``` + +### hyp delete hyp-space-template + +Delete a space template resource. + +#### Syntax + +```bash +hyp delete hyp-space-template [OPTIONS] +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `--name` | TEXT | Yes | Name of the space template to delete | +| `--namespace, -n` | TEXT | No | Kubernetes namespace | + +#### Example + +```bash +hyp delete hyp-space-template --name my-template --namespace default +``` diff --git a/doc/sdk/sdk_index.rst b/doc/sdk/sdk_index.rst index 7bdad56b..18b910de 100644 --- a/doc/sdk/sdk_index.rst +++ b/doc/sdk/sdk_index.rst @@ -9,12 +9,13 @@ SDK Reference cluster_management/hp_cluster_stack training/hyperpod_pytorch_job inference/hp_endpoint + space/hyperpod_space Complete reference for the SageMaker HyperPod SDK. .. container:: - .. grid:: 1 1 3 3 + .. grid:: 1 1 4 4 :gutter: 3 .. grid-item-card:: Cluster Management SDK @@ -38,4 +39,11 @@ Complete reference for the SageMaker HyperPod SDK. Inference SDK classes, methods and parameters. + .. grid-item-card:: Space SDK + :link: space/hyperpod_space + :link-type: doc + :class-card: sd-border-secondary + + Space SDK classes, methods and parameters. 
+ diff --git a/doc/sdk/space/hyperpod_space.rst b/doc/sdk/space/hyperpod_space.rst new file mode 100644 index 00000000..73357ac4 --- /dev/null +++ b/doc/sdk/space/hyperpod_space.rst @@ -0,0 +1,30 @@ +Space +===== + +* `HPSpace`_ +* `HPSpaceTemplate`_ +* `Space Configs`_ + + +HPSpace +------- + +.. autoclass:: sagemaker.hyperpod.space.hyperpod_space.HPSpace + :exclude-members: is_kubeconfig_loaded, model_config, get_logger, verify_kube_config + :show-inheritance: + + +HPSpaceTemplate +--------------- + +.. autoclass:: sagemaker.hyperpod.space.hyperpod_space_template.HPSpaceTemplate + :exclude-members: is_kubeconfig_loaded, get_logger, verify_kube_config + :show-inheritance: + + +Space Configs +------------- + +.. automodule:: hyperpod_space_template.v1_0.model + :members: SpaceConfig + :show-inheritance: diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index da5555bd..817d6077 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -45,7 +45,41 @@ class HPSpace(BaseModel): """HyperPod Space on Amazon SageMaker HyperPod clusters. This class provides methods to create, manage, and monitor spaces - on SageMaker HyperPod clusters orchestrated by Amazon EKS. + on SageMaker HyperPod clusters orchestrated by Amazon EKS. Spaces are + interactive workspaces that provide development environments with + configurable resources, storage, and access controls. + + **Attributes:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Attribute + - Type + - Description + * - config + - SpaceConfig + - The space configuration using the space parameter model + * - raw_resource + - Dict[str, Any], optional + - The complete Kubernetes resource data including apiVersion, kind, metadata, and status + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Create a new space + >>> from hyperpod_space_template.v1_0.model import SpaceConfig + >>> config = SpaceConfig(name="my-space", display_name="My Space") + >>> space = HPSpace(config=config) + >>> space.create() + + >>> # List all spaces + >>> spaces = HPSpace.list() + >>> for space in spaces: + ... print(f"Space: {space.config.name}") """ is_kubeconfig_loaded: ClassVar[bool] = False @@ -62,32 +96,116 @@ class HPSpace(BaseModel): @classmethod def get_logger(cls): - """Get logger for the class.""" + """Get logger for the HPSpace class. + + **Returns:** + + logging.Logger: Logger instance configured for the HPSpace class + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> logger = HPSpace.get_logger() + >>> logger.info("Space operation completed") + """ return logging.getLogger(__name__) @property def api_version(self) -> Optional[str]: - """Get the apiVersion from the Kubernetes resource.""" + """Get the apiVersion from the Kubernetes resource. + + **Returns:** + + str or None: The API version of the Kubernetes resource, or None if raw_resource is not available + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> space = HPSpace.get("my-space") + >>> print(f"API Version: {space.api_version}") + """ return self.raw_resource.get("apiVersion") if self.raw_resource else None @property def kind(self) -> Optional[str]: - """Get the kind from the Kubernetes resource.""" + """Get the kind from the Kubernetes resource. + + **Returns:** + + str or None: The kind of the Kubernetes resource, or None if raw_resource is not available + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> space = HPSpace.get("my-space") + >>> print(f"Resource Kind: {space.kind}") + """ return self.raw_resource.get("kind") if self.raw_resource else None @property def metadata(self) -> Optional[Dict[str, Any]]: - """Get the metadata from the Kubernetes resource.""" + """Get the metadata from the Kubernetes resource. + + **Returns:** + + Dict[str, Any] or None: The metadata section of the Kubernetes resource, or None if raw_resource is not available + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> space = HPSpace.get("my-space") + >>> print(f"Creation Time: {space.metadata['creationTimestamp']}") + """ return self.raw_resource.get("metadata") if self.raw_resource else None @property def status(self) -> Optional[Dict[str, Any]]: - """Get the status from the Kubernetes resource.""" + """Get the status from the Kubernetes resource. + + **Returns:** + + Dict[str, Any] or None: The status section of the Kubernetes resource, or None if raw_resource is not available + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> space = HPSpace.get("my-space") + >>> conditions = space.status.get('conditions', []) + >>> for condition in conditions: + ... print(f"{condition['type']}: {condition['status']}") + """ return self.raw_resource.get("status") if self.raw_resource else None @classmethod def verify_kube_config(cls): - """Verify and load Kubernetes configuration.""" + """Verify and load Kubernetes configuration. + + Loads the Kubernetes configuration from the default kubeconfig location + and verifies compatibility with the cluster. This method is called + automatically by other methods that interact with the Kubernetes API. + + **Raises:** + + RuntimeError: If the kubeconfig cannot be loaded or is invalid + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Verify kubeconfig before operations + >>> HPSpace.verify_kube_config() + """ if not cls.is_kubeconfig_loaded: try: config.load_kube_config() @@ -100,11 +218,39 @@ def verify_kube_config(cls): def create(self, debug: bool = False): """Create and submit the HyperPod Space to the Kubernetes cluster. - Args: - debug (bool, optional): Enable debug logging. Defaults to False. + Creates a new space resource in the Kubernetes cluster based on the + configuration provided in the space config. Validates MIG profiles + if enabled and converts the configuration to the appropriate domain model. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - debug + - bool, optional + - Enable debug logging (default: False) - Raises: - Exception: If the space creation fails or Kubernetes API call fails + **Raises:** + + RuntimeError: If MIG profile validation fails or unsupported profiles are used + Exception: If the space creation fails or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Create a space with debug logging + >>> space = HPSpace(config=space_config) + >>> space.create(debug=True) + + >>> # Create a space with default settings + >>> space.create() """ self.verify_kube_config() @@ -162,15 +308,44 @@ def create(self, debug: bool = False): def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: """List all HyperPod Spaces in the specified namespace created by the caller. - Args: - namespace (str, optional): The Kubernetes namespace to list spaces from. - If None, uses the default namespace from current context. + Retrieves all spaces that were either created by the current caller (based on + AWS STS identity) or are marked as 'Public' ownership type. Uses pagination + to handle large numbers of spaces efficiently. + + **Parameters:** + + .. 
list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - namespace + - str, optional + - The Kubernetes namespace to list spaces from. If None, uses the default namespace from current context - Returns: - List[HPSpace]: List of HPSpace instances created by the caller + **Returns:** - Raises: - Exception: If the Kubernetes API call fails or spaces cannot be retrieved + List[HPSpace]: List of HPSpace instances created by the caller or marked as public + + **Raises:** + + Exception: If the Kubernetes API call fails or spaces cannot be retrieved + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # List spaces in default namespace + >>> spaces = HPSpace.list() + >>> print(f"Found {len(spaces)} spaces") + + >>> # List spaces in specific namespace + >>> spaces = HPSpace.list(namespace="my-namespace") + >>> for space in spaces: + ... print(f"Space: {space.config.name}") """ cls.verify_kube_config() @@ -224,16 +399,46 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]: def get(cls, name: str, namespace: str = None) -> "HPSpace": """Get a specific HyperPod Space by name. - Args: - name (str): The name of the space to retrieve - namespace (str, optional): The Kubernetes namespace. - If None, uses the default namespace from current context. + Retrieves a single space resource from the Kubernetes cluster and maps + the response to the SpaceConfig model for easy access to configuration + and status information. + + **Parameters:** - Returns: - HPSpace: The space instance + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 - Raises: - Exception: If the space is not found or Kubernetes API call fails + * - Parameter + - Type + - Description + * - name + - str + - The name of the space to retrieve + * - namespace + - str, optional + - The Kubernetes namespace. 
If None, uses the default namespace from current context + + **Returns:** + + HPSpace: The space instance with configuration and raw Kubernetes resource data + + **Raises:** + + Exception: If the space is not found or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Get space from default namespace + >>> space = HPSpace.get("my-space") + >>> print(f"Space status: {space.status}") + + >>> # Get space from specific namespace + >>> space = HPSpace.get("my-space", namespace="production") + >>> print(f"Display name: {space.config.display_name}") """ cls.verify_kube_config() @@ -267,8 +472,22 @@ def get(cls, name: str, namespace: str = None) -> "HPSpace": def delete(self): """Delete the HyperPod Space from the Kubernetes cluster. - Raises: - Exception: If the deletion fails or Kubernetes API call fails + Permanently removes the space resource from the Kubernetes cluster. + This operation cannot be undone and will terminate any running + workloads associated with the space. + + **Raises:** + + Exception: If the deletion fails or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Delete a space + >>> space = HPSpace.get("my-space") + >>> space.delete() """ self.verify_kube_config() logger = self.get_logger() @@ -292,11 +511,42 @@ def delete(self): def update(self, **kwargs): """Update the HyperPod Space configuration. - Args: - **kwargs: Configuration fields to update (e.g., desired_status="Stopped") + Updates the space configuration with the provided parameters. Validates + MIG profiles if resource updates are requested and ensures compatibility + with the current node instance type. + + **Parameters:** + + .. 
list-table:: + :header-rows: 1 + :widths: 20 20 60 - Raises: - Exception: If the update fails or Kubernetes API call fails + * - Parameter + - Type + - Description + * - ``**kwargs`` + - Any + - Configuration fields to update (e.g., desired_status="Stopped", display_name="New Name") + + **Raises:** + + RuntimeError: If MIG profile validation fails or unsupported profiles are used + Exception: If the update fails or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Update space status + >>> space = HPSpace.get("my-space") + >>> space.update(desired_status="Stopped") + + >>> # Update display name and resources + >>> space.update( + ... display_name="Updated Space", + ... resources={"requests": {"cpu": "2", "memory": "4Gi"}} + ... ) """ self.verify_kube_config() logger = self.get_logger() @@ -355,19 +605,63 @@ def update(self, **kwargs): @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "start_space") def start(self): - """Start the HyperPod Space by setting desired status to Running.""" + """Start the HyperPod Space by setting desired status to Running. + + Convenience method that updates the space's desired status to "Running", + which will cause the Kubernetes operator to start the space workloads. + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Start a space + >>> space = HPSpace.get("my-space") + >>> space.start() + """ self.update(desired_status="Running") @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "stop_space") def stop(self): - """Stop the HyperPod Space by setting desired status to Stopped.""" + """Stop the HyperPod Space by setting desired status to Stopped. + + Convenience method that updates the space's desired status to "Stopped", + which will cause the Kubernetes operator to stop the space workloads. + + .. dropdown:: Usage Examples + :open: + + ..
code-block:: python + + >>> # Stop a space + >>> space = HPSpace.get("my-space") + >>> space.stop() + """ self.update(desired_status="Stopped") def list_pods(self) -> List[str]: """List all pods associated with this space. - Returns: - List[str]: List of pod names associated with the space + Retrieves all Kubernetes pods that are labeled as belonging to this + space using the workspace-name label selector. + + **Returns:** + + List[str]: List of pod names associated with the space + + **Raises:** + + Exception: If the Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # List pods for a space + >>> space = HPSpace.get("my-space") + >>> pods = space.list_pods() + >>> print(f"Found {len(pods)} pods: {pods}") """ self.verify_kube_config() logger = self.get_logger() @@ -386,13 +680,47 @@ def list_pods(self) -> List[str]: def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = None) -> str: """Get logs from a pod associated with this space. - Args: - pod_name (str, optional): Name of the pod to get logs from. - If None, gets logs from the first available pod. - container (str, optional): Name of the container to get logs from. + Retrieves logs from a specific pod and container. If no pod is specified, + uses the first available pod. If no container is specified, defaults to + the "workspace" container. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - pod_name + - str, optional + - Name of the pod to get logs from. If None, gets logs from the first available pod + * - container + - str, optional + - Name of the container to get logs from. Defaults to "workspace" + + **Returns:** - Returns: - str: The pod logs + str: The pod logs as a string + + **Raises:** + + RuntimeError: If no pods are found for the space + Exception: If the Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Get logs from default pod and container + >>> space = HPSpace.get("my-space") + >>> logs = space.get_logs() + >>> print(logs) + + >>> # Get logs from specific pod and container + >>> logs = space.get_logs(pod_name="my-pod", container="sidecar") """ self.verify_kube_config() logger = self.get_logger() @@ -421,14 +749,44 @@ def get_logs(self, pod_name: Optional[str] = None, container: Optional[str] = No def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[str, str]: """Create a space access for this space. - Args: - connection_type (str, optional): The IDE type for remote access. Defaults to "vscode-remote". + Creates a space access resource that provides remote connection capabilities + to the space. Supports VS Code remote development and web UI access types. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - connection_type + - str, optional + - The IDE type for remote access. Must be "vscode-remote" or "web-ui" (default: "vscode-remote") + + **Returns:** + + Dict[str, str]: Dictionary containing 'SpaceConnectionType' and 'SpaceConnectionUrl' keys + + **Raises:** + + ValueError: If connection_type is not "vscode-remote" or "web-ui" + Exception: If the space access creation fails or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: - Returns: - Dict[str, str]: Dictionary with 'SpaceConnectionType' and 'SpaceConnectionUrl' keys + .. 
code-block:: python - Raises: - Exception: If the space access creation fails + >>> # Create VS Code remote access + >>> space = HPSpace.get("my-space") + >>> access = space.create_space_access("vscode-remote") + >>> print(f"Connection URL: {access['SpaceConnectionUrl']}") + + >>> # Create web UI access + >>> access = space.create_space_access("web-ui") + >>> print(f"Web UI URL: {access['SpaceConnectionUrl']}") """ self.verify_kube_config() logger = self.get_logger() diff --git a/src/sagemaker/hyperpod/space/hyperpod_space_template.py b/src/sagemaker/hyperpod/space/hyperpod_space_template.py index 0c596372..1ce8ccb0 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space_template.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space_template.py @@ -24,20 +24,85 @@ class HPSpaceTemplate: """HyperPod Space Template on Amazon SageMaker HyperPod clusters. This class provides methods to create, manage, and monitor space templates - on SageMaker HyperPod clusters orchestrated by Amazon EKS. + on SageMaker HyperPod clusters orchestrated by Amazon EKS. Space templates + define reusable configurations for creating spaces with predefined settings, + resources, and constraints. + + **Attributes:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Attribute + - Type + - Description + * - config_data + - Dict[str, Any] + - Dictionary containing the complete template configuration + * - name + - str + - Name of the space template extracted from metadata + * - namespace + - str + - Kubernetes namespace of the template extracted from metadata + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Create template from YAML file + >>> template = HPSpaceTemplate(file_path="template.yaml") + >>> template.create() + + >>> # List all templates + >>> templates = HPSpaceTemplate.list() + >>> for template in templates: + ... 
print(f"Template: {template.name}") """ is_kubeconfig_loaded: ClassVar[bool] = False def __init__(self, *, file_path: Optional[str] = None, config_data: Optional[Dict[str, Any]] = None): """Initialize space template with config YAML file path or dictionary data. - - Args: - file_path: Path to YAML configuration file - config_data: Dictionary containing configuration data - - Raises: - ValueError: If both or neither parameters are provided + + Creates a new HPSpaceTemplate instance from either a YAML configuration file + or a dictionary containing configuration data. Exactly one of the parameters + must be provided. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - file_path + - str, optional + - Path to YAML configuration file (keyword-only) + * - config_data + - Dict[str, Any], optional + - Dictionary containing configuration data (keyword-only) + + **Raises:** + + ValueError: If both or neither parameters are provided, or if YAML parsing fails + FileNotFoundError: If the specified file path does not exist + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Initialize from YAML file + >>> template = HPSpaceTemplate(file_path="my-template.yaml") + + >>> # Initialize from dictionary (e.g., from API response) + >>> config = {"metadata": {"name": "my-template"}, "spec": {...}} + >>> template = HPSpaceTemplate(config_data=config) """ if (file_path is None) == (config_data is None): raise ValueError("Exactly one of 'file_path' or 'config_data' must be provided") @@ -60,12 +125,42 @@ def __init__(self, *, file_path: Optional[str] = None, config_data: Optional[Dic @classmethod def get_logger(cls): - """Get logger for the class.""" + """Get logger for the HPSpaceTemplate class. + + **Returns:** + + logging.Logger: Logger instance configured for the HPSpaceTemplate class + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> logger = HPSpaceTemplate.get_logger() + >>> logger.info("Template operation completed") + """ return logging.getLogger(__name__) @classmethod def verify_kube_config(cls): - """Verify and load Kubernetes configuration.""" + """Verify and load Kubernetes configuration. + + Loads the Kubernetes configuration from the default kubeconfig location + and verifies compatibility with the cluster. This method is called + automatically by other methods that interact with the Kubernetes API. + + **Raises:** + + Exception: If the kubeconfig cannot be loaded or is invalid + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Verify kubeconfig before operations + >>> HPSpaceTemplate.verify_kube_config() + """ if not cls.is_kubeconfig_loaded: config.load_kube_config() cls.is_kubeconfig_loaded = True @@ -73,10 +168,30 @@ def verify_kube_config(cls): @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space_template") def create(self) -> "HPSpaceTemplate": - """Create the space template in the cluster. - - Returns: - Updated HPSpaceTemplate instance with server response + """Create the space template in the Kubernetes cluster. + + Submits the space template configuration to the Kubernetes cluster and + creates a new template resource. Updates the instance with the server + response including generated metadata. + + **Returns:** + + HPSpaceTemplate: Updated HPSpaceTemplate instance with server response data + + **Raises:** + + ApiException: If the Kubernetes API call fails + Exception: If template creation fails for other reasons + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Create template from file + >>> template = HPSpaceTemplate(file_path="template.yaml") + >>> created_template = template.create() + >>> print(f"Created template: {created_template.name}") """ self.verify_kube_config() @@ -104,12 +219,45 @@ def create(self) -> "HPSpaceTemplate": def list(cls, namespace: Optional[str] = None) -> List["HPSpaceTemplate"]: """List all space templates in the specified namespace. - Args: - namespace (str, optional): The Kubernetes namespace to list space templates from. - If None, uses the default namespace from current context. - - Returns: - List of HPSpaceTemplate instances + Retrieves all space template resources from the Kubernetes cluster in the + specified namespace. If no namespace is provided, uses the default namespace + from the current Kubernetes context. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - namespace + - str, optional + - The Kubernetes namespace to list space templates from. If None, uses the default namespace from current context + + **Returns:** + + List[HPSpaceTemplate]: List of HPSpaceTemplate instances found in the namespace + + **Raises:** + + ApiException: If the Kubernetes API call fails + Exception: If template listing fails for other reasons + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # List templates in default namespace + >>> templates = HPSpaceTemplate.list() + >>> print(f"Found {len(templates)} templates") + + >>> # List templates in specific namespace + >>> templates = HPSpaceTemplate.list(namespace="production") + >>> for template in templates: + ... 
print(f"Template: {template.name}") """ cls.verify_kube_config() @@ -141,14 +289,47 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpaceTemplate"]: @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_space_template") def get(cls, name: str, namespace: Optional[str] = None) -> "HPSpaceTemplate": """Get a specific space template by name. - - Args: - name: Name of the space template - namespace (str, optional): The Kubernetes namespace. - If None, uses the default namespace from current context. - - Returns: - HPSpaceTemplate instance + + Retrieves a single space template resource from the Kubernetes cluster + by name. Removes managedFields from the metadata for cleaner output. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - name + - str + - Name of the space template to retrieve + * - namespace + - str, optional + - The Kubernetes namespace. If None, uses the default namespace from current context + + **Returns:** + + HPSpaceTemplate: The space template instance with configuration data + + **Raises:** + + ApiException: If the template is not found or Kubernetes API call fails + Exception: If template retrieval fails for other reasons + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Get template from default namespace + >>> template = HPSpaceTemplate.get("my-template") + >>> print(f"Template display name: {template.config_data['spec']['displayName']}") + + >>> # Get template from specific namespace + >>> template = HPSpaceTemplate.get("my-template", namespace="production") + >>> print(template.to_yaml()) """ cls.verify_kube_config() @@ -179,7 +360,26 @@ def get(cls, name: str, namespace: Optional[str] = None) -> "HPSpaceTemplate": @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_space_template") def delete(self) -> None: - """Delete the space template from the cluster.""" + """Delete the space template from the Kubernetes cluster. 
+ + Permanently removes the space template resource from the Kubernetes cluster. + This operation cannot be undone. Any spaces created from this template + will continue to exist but will no longer reference the template. + + **Raises:** + + ApiException: If the deletion fails or Kubernetes API call fails + Exception: If template deletion fails for other reasons + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Delete a template + >>> template = HPSpaceTemplate.get("my-template") + >>> template.delete() + """ self.verify_kube_config() try: @@ -202,13 +402,45 @@ def delete(self) -> None: @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "update_space_template") def update(self, file_path: str) -> "HPSpaceTemplate": - """Update the space template from a YAML file. - - Args: - file_path: Path to the YAML configuration file - - Returns: - Updated HPSpaceTemplate instance + """Update the space template from a YAML configuration file. + + Updates the existing space template with new configuration from a YAML file. + Validates that the template name in the file matches the current template name + and removes immutable fields before applying the update. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - file_path + - str + - Path to the YAML configuration file containing updated template configuration + + **Returns:** + + HPSpaceTemplate: Updated HPSpaceTemplate instance with server response data + + **Raises:** + + FileNotFoundError: If the specified file path does not exist + ValueError: If YAML parsing fails or template name mismatch occurs + ApiException: If the Kubernetes API call fails + Exception: If template update fails for other reasons + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Update template from file + >>> template = HPSpaceTemplate.get("my-template") + >>> updated_template = template.update("updated-template.yaml") + >>> print(f"Updated template: {updated_template.name}") """ self.verify_kube_config() @@ -251,16 +483,53 @@ def update(self, file_path: str) -> "HPSpaceTemplate": def to_yaml(self) -> str: """Convert the space template to YAML format. - - Returns: - YAML string representation + + Serializes the template configuration data to a YAML string representation + with readable formatting (non-flow style). + + **Returns:** + + str: YAML string representation of the template configuration + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Convert template to YAML + >>> template = HPSpaceTemplate.get("my-template") + >>> yaml_content = template.to_yaml() + >>> print(yaml_content) + + >>> # Save template to file + >>> with open("exported-template.yaml", "w") as f: + ... f.write(template.to_yaml()) """ return yaml.dump(self.config_data, default_flow_style=False) def to_dict(self) -> Dict[str, Any]: """Convert the space template to dictionary format. - - Returns: - Dictionary representation + + Returns the template configuration data as a dictionary, which can be + used for programmatic access to template properties or serialization + to other formats. + + **Returns:** + + Dict[str, Any]: Dictionary representation of the template configuration + + .. dropdown:: Usage Examples + :open: + + .. 
code-block:: python + + >>> # Get template as dictionary + >>> template = HPSpaceTemplate.get("my-template") + >>> config_dict = template.to_dict() + >>> print(f"Template spec: {config_dict['spec']}") + + >>> # Access specific configuration values + >>> display_name = config_dict['spec']['displayName'] + >>> default_image = config_dict['spec']['defaultImage'] """ return self.config_data From 2ce7ab0f17fe67860a52809e907c84047f2d842b Mon Sep 17 00:00:00 2001 From: aws-brianxia Date: Thu, 20 Nov 2025 17:59:55 -0800 Subject: [PATCH 24/31] Implement Space integration tests (#298) Inference tests succeeded with parker-cli code - https://quip-amazon.com/fhwhAAMht0Mm/Project-Parker-HyperPod-User-Experience-for-Data-Scientist-persona Parker-cli integ tests pass (shown below) These inference tests failing are known to be flaky- https://w.amazon.com/bin/view/AWS/AmazonAI/Platform/Codex/CodexInfra/Runbooks/HyperPodCLI/TroubleshootInferenceTests#HTroubleshooting ticket has been created to fix these flaky tests - https://t.corp.amazon.com/V1943878058 Parker-cli integ tests passing ============================= test session starts ============================== platform linux -- Python 3.11.14, pytest-8.3.2, pluggy-1.6.0 -- /root/.pyenv/versions/3.11.14/bin/python3.11 cachedir: .pytest_cache rootdir: /codebuild/output/src1458832038/src/github.com/aws/private-sagemaker-hyperpod-cli-staging configfile: setup.cfg plugins: hydra-core-1.3.2, order-1.3.0, dependency-0.6.0, cov-5.0.0 collecting ... 
collected 39 items test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_create PASSED [ 2%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_list_table PASSED [ 5%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_list_json PASSED [ 7%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_describe_yaml PASSED [ 10%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_describe_json PASSED [ 12%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_stop PASSED [ 15%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_start PASSED [ 17%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_update PASSED [ 20%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_get_logs PASSED [ 23%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_delete PASSED [ 25%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_list_empty_namespace PASSED [ 28%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_describe_nonexistent PASSED [ 30%] test/integration_tests/space/cli/test_cli_space.py::TestSpaceCLI::test_space_delete_nonexistent PASSED [ 33%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_create PASSED [ 35%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_list_table PASSED [ 38%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_list_json PASSED [ 41%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_describe_yaml PASSED [ 43%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_describe_json PASSED [ 46%] 
test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_update PASSED [ 48%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_delete PASSED [ 51%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_list_empty_namespace PASSED [ 53%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_describe_nonexistent PASSED [ 56%] test/integration_tests/space/cli/test_cli_space_template.py::TestSpaceTemplateCLI::test_space_template_delete_nonexistent PASSED [ 58%] test/integration_tests/space/sdk/test_sdk_space.py::test_create_space PASSED [ 61%] test/integration_tests/space/sdk/test_sdk_space.py::test_list_spaces PASSED [ 64%] test/integration_tests/space/sdk/test_sdk_space.py::test_get_space PASSED [ 66%] test/integration_tests/space/sdk/test_sdk_space.py::test_wait_until_running PASSED [ 69%] test/integration_tests/space/sdk/test_sdk_space.py::test_update_space PASSED [ 71%] test/integration_tests/space/sdk/test_sdk_space.py::test_stop_space PASSED [ 74%] test/integration_tests/space/sdk/test_sdk_space.py::test_start_space PASSED [ 76%] test/integration_tests/space/sdk/test_sdk_space.py::test_list_pods PASSED [ 79%] test/integration_tests/space/sdk/test_sdk_space.py::test_get_logs PASSED [ 82%] test/integration_tests/space/sdk/test_sdk_space.py::test_create_space_access SKIPPED [ 84%] test/integration_tests/space/sdk/test_sdk_space.py::test_delete_space PASSED [ 87%] test/integration_tests/space/sdk/test_sdk_space_template.py::TestHPSpaceTemplate::test_create_template PASSED [ 89%] test/integration_tests/space/sdk/test_sdk_space_template.py::TestHPSpaceTemplate::test_list_templates PASSED [ 92%] test/integration_tests/space/sdk/test_sdk_space_template.py::TestHPSpaceTemplate::test_get_template PASSED [ 94%] 
test/integration_tests/space/sdk/test_sdk_space_template.py::TestHPSpaceTemplate::test_update_template PASSED [ 97%] test/integration_tests/space/sdk/test_sdk_space_template.py::TestHPSpaceTemplate::test_delete_template PASSED [100%] =============================== warnings summary =============================== --- .../space/cli/test_cli_space.py | 164 ++++++++++++ .../space/cli/test_cli_space_template.py | 251 ++++++++++++++++++ .../space/sdk/test_sdk_space.py | 161 +++++++++++ .../space/sdk/test_sdk_space_template.py | 141 ++++++++++ test/unit_tests/cli/test_space.py | 3 - 5 files changed, 717 insertions(+), 3 deletions(-) create mode 100644 test/integration_tests/space/cli/test_cli_space.py create mode 100644 test/integration_tests/space/cli/test_cli_space_template.py create mode 100644 test/integration_tests/space/sdk/test_sdk_space.py create mode 100644 test/integration_tests/space/sdk/test_sdk_space_template.py diff --git a/test/integration_tests/space/cli/test_cli_space.py b/test/integration_tests/space/cli/test_cli_space.py new file mode 100644 index 00000000..b0d0a012 --- /dev/null +++ b/test/integration_tests/space/cli/test_cli_space.py @@ -0,0 +1,164 @@ +import time +import pytest +from click.testing import CliRunner +from sagemaker.hyperpod.cli.commands.space import ( + space_create, space_list, space_describe, space_delete, + space_update, space_start, space_stop, space_get_logs +) +from test.integration_tests.utils import get_time_str + +# --------- Test Configuration --------- +NAMESPACE = "default" +VERSION = "1.0" +SPACE_NAME = "space-cli-integ-test" + get_time_str() +DISPLAY_NAME = f"Space CLI Integ Test {get_time_str()}" + + +@pytest.fixture(scope="module") +def runner(): + return CliRunner() + +@pytest.fixture(scope="module") +def space_name(): + return SPACE_NAME + +class TestSpaceCLI: + """Integration tests for HyperPod Space CLI commands.""" + + @pytest.mark.dependency(name="create") + def test_space_create(self, runner, space_name): + 
"""Test creating a space via CLI.""" + result = runner.invoke(space_create, [ + "--name", space_name, + "--display-name", DISPLAY_NAME, + "--namespace", NAMESPACE, + ]) + assert result.exit_code == 0, result.output + assert f"Space '{space_name}' created successfully" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_list_table(self, runner, space_name): + """Test listing spaces in table format.""" + result = runner.invoke(space_list, [ + "--namespace", NAMESPACE, + "--output", "table" + ]) + assert result.exit_code == 0, result.output + assert space_name in result.output + assert "NAME" in result.output + assert "NAMESPACE" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_list_json(self, runner, space_name): + """Test listing spaces in JSON format.""" + result = runner.invoke(space_list, [ + "--namespace", NAMESPACE, + "--output", "json" + ]) + assert result.exit_code == 0, result.output + assert space_name in result.output + # Verify it's valid JSON by checking for brackets + assert "[" in result.output and "]" in result.output + + @pytest.mark.dependency(name="describe", depends=["create"]) + def test_space_describe_yaml(self, runner, space_name): + """Test describing a space in YAML format.""" + result = runner.invoke(space_describe, [ + "--name", space_name, + "--namespace", NAMESPACE, + "--output", "yaml" + ]) + assert result.exit_code == 0, result.output + assert space_name in result.output + assert "apiVersion:" in result.output + assert "kind:" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_describe_json(self, runner, space_name): + """Test describing a space in JSON format.""" + result = runner.invoke(space_describe, [ + "--name", space_name, + "--namespace", NAMESPACE, + "--output", "json" + ]) + assert result.exit_code == 0, result.output + assert space_name in result.output + assert "{" in result.output and "}" in result.output + + 
@pytest.mark.dependency(depends=["create"]) + def test_space_stop(self, runner, space_name): + """Test stopping a space.""" + result = runner.invoke(space_stop, [ + "--name", space_name, + "--namespace", NAMESPACE + ]) + assert result.exit_code == 0, result.output + assert f"Space '{space_name}' stop requested" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_start(self, runner, space_name): + """Test starting a space.""" + result = runner.invoke(space_start, [ + "--name", space_name, + "--namespace", NAMESPACE + ]) + assert result.exit_code == 0, result.output + assert f"Space '{space_name}' start requested" in result.output + + @pytest.mark.dependency(depends=["create", "describe"]) + def test_space_update(self, runner, space_name): + """Test updating a space.""" + result = runner.invoke(space_update, [ + "--name", space_name, + "--namespace", NAMESPACE, + "--display-name", f"Updated {DISPLAY_NAME}", + ]) + assert result.exit_code == 0, result.output + assert f"Space '{space_name}' updated successfully" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_get_logs(self, runner, space_name): + """Test getting logs from a space.""" + # This might fail if no pods are running, which is acceptable + result = runner.invoke(space_get_logs, [ + "--name", space_name, + "--namespace", NAMESPACE + ]) + # Don't assert exit code as logs might not be available + # Just verify the command runs without crashing + assert isinstance(result.exit_code, int) + + @pytest.mark.dependency(depends=["create"]) + def test_space_delete(self, runner, space_name): + """Test deleting a space.""" + result = runner.invoke(space_delete, [ + "--name", space_name, + "--namespace", NAMESPACE + ]) + assert result.exit_code == 0, result.output + assert f"Requested deletion for Space '{space_name}'" in result.output + + def test_space_list_empty_namespace(self, runner): + """Test listing spaces in an empty namespace.""" + result = 
runner.invoke(space_list, [ + "--namespace", "nonexistent-namespace", + "--output", "table" + ]) + assert result.exit_code == 0, result.output + assert "No spaces found" in result.output + + def test_space_describe_nonexistent(self, runner): + """Test describing a nonexistent space.""" + result = runner.invoke(space_describe, [ + "--name", "nonexistent-space", + "--namespace", NAMESPACE + ]) + assert result.exit_code != 0 + + def test_space_delete_nonexistent(self, runner): + """Test deleting a nonexistent space.""" + result = runner.invoke(space_delete, [ + "--name", "nonexistent-space", + "--namespace", NAMESPACE + ]) + assert result.exit_code != 0 diff --git a/test/integration_tests/space/cli/test_cli_space_template.py b/test/integration_tests/space/cli/test_cli_space_template.py new file mode 100644 index 00000000..baee8b50 --- /dev/null +++ b/test/integration_tests/space/cli/test_cli_space_template.py @@ -0,0 +1,251 @@ +import pytest +import tempfile +import os +import yaml +import json +from click.testing import CliRunner +from sagemaker.hyperpod.cli.commands.space_template import ( + space_template_create, space_template_list, space_template_describe, + space_template_delete, space_template_update +) +from test.integration_tests.utils import get_time_str + +# --------- Test Configuration --------- +NAMESPACE = "default" +TEMPLATE_NAME = "space-template-cli-integ-test" + get_time_str() + +# Template configuration aligned with template.yaml +TEMPLATE_CONFIG = { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "WorkspaceTemplate", + "metadata": { + "name": TEMPLATE_NAME, + "namespace": NAMESPACE + }, + "spec": { + "displayName": f"Space Template CLI Integ Test {get_time_str()}", + "description": "Integration test template for Space Template CLI", + "defaultImage": "jk8s-application-jupyter-uv:latest", + "allowedImages": [ + "jk8s-application-jupyter-uv:latest" + ], + "defaultResources": { + "requests": { + "cpu": "200m", + "memory": "256Mi" + }, + 
"limits": { + "cpu": "500m", + "memory": "512Mi" + } + }, + "resourceBounds": { + "cpu": { + "min": "100m", + "max": "2" + }, + "memory": { + "min": "128Mi", + "max": "4Gi" + }, + "gpu": { + "min": "0", + "max": "1" + } + }, + "primaryStorage": { + "defaultSize": "1Gi", + "minSize": "100Mi", + "maxSize": "20Gi" + }, + "appType": "jupyter" + } +} + +@pytest.fixture(scope="module") +def runner(): + return CliRunner() + +@pytest.fixture(scope="module") +def template_yaml_file(): + """Create a temporary YAML file with template configuration.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(TEMPLATE_CONFIG, f) + temp_file = f.name + + yield temp_file + + # Cleanup + if os.path.exists(temp_file): + os.unlink(temp_file) + +@pytest.fixture(scope="module") +def template_name(): + return TEMPLATE_NAME + +class TestSpaceTemplateCLI: + """Integration tests for HyperPod Space Template CLI commands.""" + + @pytest.mark.dependency(name="create") + def test_space_template_create(self, runner, template_yaml_file, template_name): + """Test creating a space template via CLI.""" + result = runner.invoke(space_template_create, [ + "--file", template_yaml_file + ]) + assert result.exit_code == 0, result.output + assert f"Space template '{template_name}' in namespace '{NAMESPACE}' created successfully" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_template_list_table(self, runner, template_name): + """Test listing space templates in table format.""" + result = runner.invoke(space_template_list, [ + "--namespace", NAMESPACE, + "--output", "table" + ]) + assert result.exit_code == 0, result.output + assert template_name in result.output + assert "NAMESPACE" in result.output + assert "NAME" in result.output + assert "DISPLAY_NAME" in result.output + assert "DEFAULT_IMAGE" in result.output + + @pytest.mark.dependency(depends=["create"]) + def test_space_template_list_json(self, runner, template_name): + 
"""Test listing space templates in JSON format.""" + result = runner.invoke(space_template_list, [ + "--namespace", NAMESPACE, + "--output", "json" + ]) + assert result.exit_code == 0, result.output + assert template_name in result.output + + # Verify it's valid JSON + try: + templates_data = json.loads(result.output) + assert isinstance(templates_data, list) + + # Find our template in the list + our_template = next((t for t in templates_data if t.get("metadata", {}).get("name") == template_name), None) + assert our_template is not None + + except json.JSONDecodeError: + pytest.fail("Invalid JSON output from space template list command") + + @pytest.mark.dependency(name="describe", depends=["create"]) + def test_space_template_describe_yaml(self, runner, template_name): + """Test describing a space template in YAML format.""" + result = runner.invoke(space_template_describe, [ + "--name", template_name, + "--namespace", NAMESPACE, + "--output", "yaml" + ]) + assert result.exit_code == 0, result.output + assert template_name in result.output + assert "apiVersion:" in result.output + assert "kind:" in result.output + + # Verify YAML structure + try: + template_data = yaml.safe_load(result.output) + assert template_data["metadata"]["name"] == template_name + assert template_data["metadata"]["namespace"] == NAMESPACE + + except yaml.YAMLError: + pytest.fail("Invalid YAML output from space template describe command") + + @pytest.mark.dependency(depends=["create"]) + def test_space_template_describe_json(self, runner, template_name): + """Test describing a space template in JSON format.""" + result = runner.invoke(space_template_describe, [ + "--name", template_name, + "--namespace", NAMESPACE, + "--output", "json" + ]) + assert result.exit_code == 0, result.output + assert template_name in result.output + + # Verify JSON structure + try: + template_data = json.loads(result.output) + assert template_data["metadata"]["name"] == template_name + assert 
template_data["metadata"]["namespace"] == NAMESPACE + assert template_data["kind"] == "WorkspaceTemplate" + + except json.JSONDecodeError: + pytest.fail("Invalid JSON output from space template describe command") + + @pytest.mark.dependency(depends=["create", "describe"]) + def test_space_template_update(self, runner, template_name): + """Test updating a space template.""" + # Create updated config + updated_config = TEMPLATE_CONFIG.copy() + updated_config["spec"]["description"] = "Updated CLI integration test template" + updated_config["spec"]["defaultResources"]["requests"]["cpu"] = "300m" + + # Create temporary file with updated config + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(updated_config, f) + temp_file = f.name + + try: + result = runner.invoke(space_template_update, [ + "--name", template_name, + "--namespace", NAMESPACE, + "--file", temp_file + ]) + assert result.exit_code == 0, result.output + assert f"Space template '{template_name}' in namespace '{NAMESPACE}' updated successfully" in result.output + + # Verify update by describing the template + describe_result = runner.invoke(space_template_describe, [ + "--name", template_name, + "--namespace", NAMESPACE, + "--output", "json" + ]) + assert describe_result.exit_code == 0 + + try: + template_data = json.loads(describe_result.output) + assert template_data["spec"]["description"] == "Updated CLI integration test template" + assert template_data["spec"]["defaultResources"]["requests"]["cpu"] == "300m" + except json.JSONDecodeError: + pytest.fail("Invalid JSON output from space template describe after update") + + finally: + if os.path.exists(temp_file): + os.unlink(temp_file) + + @pytest.mark.dependency(depends=["create"]) + def test_space_template_delete(self, runner, template_name): + """Test deleting a space template.""" + result = runner.invoke(space_template_delete, [ + "--name", template_name, + "--namespace", NAMESPACE + ]) + assert 
result.exit_code == 0, result.output + assert f"Requested deletion for Space template '{template_name}' in namespace '{NAMESPACE}'" in result.output + + def test_space_template_list_empty_namespace(self, runner): + """Test listing space templates in an empty namespace.""" + result = runner.invoke(space_template_list, [ + "--namespace", "nonexistent-namespace", + "--output", "table" + ]) + assert result.exit_code == 0, result.output + assert "No space templates found" in result.output + + def test_space_template_describe_nonexistent(self, runner): + """Test describing a nonexistent space template.""" + result = runner.invoke(space_template_describe, [ + "--name", "nonexistent-template", + "--namespace", NAMESPACE + ]) + assert result.exit_code != 0 + + def test_space_template_delete_nonexistent(self, runner): + """Test deleting a nonexistent space template.""" + result = runner.invoke(space_template_delete, [ + "--name", "nonexistent-template", + "--namespace", NAMESPACE + ]) + assert result.exit_code != 0 diff --git a/test/integration_tests/space/sdk/test_sdk_space.py b/test/integration_tests/space/sdk/test_sdk_space.py new file mode 100644 index 00000000..b34a4fb4 --- /dev/null +++ b/test/integration_tests/space/sdk/test_sdk_space.py @@ -0,0 +1,161 @@ +import time +import pytest +from sagemaker.hyperpod.space.hyperpod_space import HPSpace +from hyperpod_space_template.v1_0.model import SpaceConfig, ResourceRequirements +from test.integration_tests.utils import get_time_str + +# --------- Config --------- +NAMESPACE = "default" +SPACE_NAME = "space-sdk-integration-test-" + get_time_str() +DISPLAY_NAME = f"Space SDK Integration Test {get_time_str()}" + +# Basic configuration for testing +TIMEOUT_MINUTES = 2 +POLL_INTERVAL_SECONDS = 13 + +@pytest.fixture(scope="module") +def space_config(): + """Create a basic space configuration for testing.""" + return SpaceConfig( + name=SPACE_NAME, + display_name=DISPLAY_NAME, + namespace=NAMESPACE, + ) + 
+@pytest.fixture(scope="module") +def space_obj(space_config): + """Create an HPSpace instance for testing.""" + return HPSpace(config=space_config) + +@pytest.mark.dependency(name="create") +def test_create_space(space_obj): + """Test creating a space.""" + space_obj.create() + assert space_obj.config.name == SPACE_NAME + +@pytest.mark.dependency(depends=["create"]) +def test_list_spaces(): + """Test listing spaces.""" + spaces = HPSpace.list(namespace=NAMESPACE) + names = [space.config.name for space in spaces] + assert SPACE_NAME in names + +@pytest.mark.dependency(name="get", depends=["create"]) +def test_get_space(): + """Test getting a specific space.""" + space = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE) + assert space.config.name == SPACE_NAME + assert space.config.display_name == DISPLAY_NAME + +@pytest.mark.dependency(name="wait_until_running", depends=["create"]) +def test_wait_until_running(): + """Poll until space reaches Running status.""" + print(f"[INFO] Waiting for space '{SPACE_NAME}' to be Running...") + deadline = time.time() + (TIMEOUT_MINUTES * 60) + poll_count = 0 + + while time.time() < deadline: + poll_count += 1 + print(f"[DEBUG] Poll #{poll_count}: Checking space status...") + + try: + space = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE) + if space.status: + conditions = {c['type']: c['status'] for c in space.status['conditions']} + if conditions.get('Available', None) == "True": + print("[INFO] Space is Running.") + return + else: + print("[DEBUG] No status available yet") + + except Exception as e: + print(f"[ERROR] Exception during polling: {e}") + + time.sleep(POLL_INTERVAL_SECONDS) + + pytest.fail("[ERROR] Timed out waiting for space to be Running") + +@pytest.mark.dependency(name="update", depends=["wait_until_running"]) +def test_update_space(): + """Test updating space configuration.""" + space = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE) + + # Update resources + new_resources = ResourceRequirements( + 
@pytest.mark.dependency(name="stop", depends=["update"])
def test_stop_space():
    """Stop the space and check that its desired status reads "Stopped"."""
    HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE).stop()

    # Fetch the space again before asserting on its desired status.
    refreshed = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE)
    assert refreshed.config.desired_status == "Stopped"


@pytest.mark.dependency(depends=["stop"])
def test_start_space():
    """Start the space and check that its desired status reads "Running"."""
    HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE).start()

    # Fetch the space again before asserting on its desired status.
    refreshed = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE)
    assert refreshed.config.desired_status == "Running"


@pytest.mark.dependency(depends=["create", "wait_until_running"])
def test_list_pods():
    """Check that listing the pods of the space returns a list."""
    target = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE)
    # The list may legitimately be empty; only the return type is checked.
    assert isinstance(target.list_pods(), list)


@pytest.mark.dependency(depends=["create", "wait_until_running"])
def test_get_logs():
    """Fetch logs from the first pod of the space, tolerating missing logs."""
    target = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE)

    pod_names = target.list_pods()
    if not pod_names:
        print("[INFO] No pods available for log retrieval")
        return

    try:
        log_text = target.get_logs(pod_name=pod_names[0])
        assert isinstance(log_text, str)
    except Exception as err:
        # Best effort: any failure here (including the assertion) is only reported.
        print(f"[INFO] Logs not available yet: {err}")


@pytest.mark.skip(reason="Skipping space access test due to an operator setup issue")
@pytest.mark.dependency(depends=["create", "wait_until_running"])
def test_create_space_access():
    """Create a vscode-remote space access and check the returned fields."""
    target = HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE)
    connection = target.create_space_access(connection_type="vscode-remote")
    assert "SpaceConnectionType" in connection
    assert "SpaceConnectionUrl" in connection
    assert connection["SpaceConnectionType"] == "vscode-remote"


@pytest.mark.dependency(depends=["create"])
def test_delete_space():
    """Delete the space and check that it no longer appears in the listing."""
    HPSpace.get(name=SPACE_NAME, namespace=NAMESPACE).delete()

    # Wait a fixed minute before listing, to let the deletion propagate.
    time.sleep(60)
    remaining_names = [item.config.name for item in HPSpace.list(namespace=NAMESPACE)]
    assert SPACE_NAME not in remaining_names
@pytest.fixture(scope="module")
def template_yaml_file():
    """Write TEMPLATE_CONFIG to a temporary YAML file and yield its path.

    The file is removed after the module's tests have finished.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
        yaml.dump(TEMPLATE_CONFIG, f)
        temp_file = f.name

    yield temp_file

    # Cleanup
    if os.path.exists(temp_file):
        os.unlink(temp_file)


@pytest.fixture(scope="module")
def template_obj_from_file(template_yaml_file):
    """Create HPSpaceTemplate from YAML file."""
    return HPSpaceTemplate(file_path=template_yaml_file)


@pytest.fixture(scope="module")
def template_obj_from_dict():
    """Create HPSpaceTemplate from dictionary."""
    return HPSpaceTemplate(config_data=TEMPLATE_CONFIG)


class TestHPSpaceTemplate:
    """Integration tests for HyperPod Space Template SDK."""

    @pytest.mark.dependency(name="create")
    def test_create_template(self, template_obj_from_dict):
        """Test creating a space template."""
        template_obj_from_dict.create()
        assert template_obj_from_dict.name == TEMPLATE_NAME

    @pytest.mark.dependency(depends=["create"])
    def test_list_templates(self):
        """Test listing space templates."""
        templates = HPSpaceTemplate.list(namespace=NAMESPACE)
        names = [template.name for template in templates]
        assert TEMPLATE_NAME in names

    @pytest.mark.dependency(name="get", depends=["create"])
    def test_get_template(self):
        """Test getting a specific space template."""
        template = HPSpaceTemplate.get(name=TEMPLATE_NAME, namespace=NAMESPACE)
        assert template.name == TEMPLATE_NAME
        assert template.namespace == NAMESPACE
        assert template.config_data["spec"]["defaultImage"] == "jk8s-application-jupyter-uv:latest"

    @pytest.mark.dependency(depends=["create", "get"])
    def test_update_template(self):
        """Test updating a space template."""
        import copy

        # Deep copy: a shallow dict.copy() shares the nested "spec" dicts, so the
        # edits below would silently rewrite the module-level TEMPLATE_CONFIG
        # that the module-scoped fixtures and other tests also read.
        updated_config = copy.deepcopy(TEMPLATE_CONFIG)
        updated_config["spec"]["description"] = "Updated integration test template"
        updated_config["spec"]["defaultResources"]["requests"]["cpu"] = "300m"

        # Create temporary file with updated config
        with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
            yaml.dump(updated_config, f)
            temp_file = f.name

        try:
            template = HPSpaceTemplate.get(name=TEMPLATE_NAME, namespace=NAMESPACE)
            template.update(file_path=temp_file)

            # Verify update
            updated_template = HPSpaceTemplate.get(name=TEMPLATE_NAME, namespace=NAMESPACE)
            assert updated_template.config_data["spec"]["description"] == "Updated integration test template"
            assert updated_template.config_data["spec"]["defaultResources"]["requests"]["cpu"] == "300m"
        finally:
            # Always remove the temporary file, even when the update or asserts fail.
            if os.path.exists(temp_file):
                os.unlink(temp_file)

    @pytest.mark.dependency(depends=["create"])
    def test_delete_template(self):
        """Test deleting a space template."""
        template = HPSpaceTemplate.get(name=TEMPLATE_NAME, namespace=NAMESPACE)
        template.delete()

        # Verify template is deleted
        templates = HPSpaceTemplate.list(namespace=NAMESPACE)
        names = [template.name for template in templates]
        assert TEMPLATE_NAME not in names
mock_hp_space_instance.raw_resource = mock_resource - # mock_hp_space_class.get.return_value = mock_hp_space_instance with patch('yaml.dump') as mock_yaml_dump: mock_yaml_dump.return_value = "yaml_output" From caf618deebda609e15b7ffab3b0348fd83000d98 Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 21:02:00 -0800 Subject: [PATCH 25/31] merge conflicts fixed --- .../v1_1/model.py | 100 +++++++--- .../v1_1/schema.json | 23 +++ .../v1_1/template.py | 17 +- .../hyperpod/cli/commands/training.py | 21 +- src/sagemaker/hyperpod/cli/hyp_cli.py | 2 + .../training/accelerator_partition_util.py | 125 ++++++++++++ src/sagemaker/hyperpod/training/constants.py | 131 ++++++++++++ .../hyperpod/training/hyperpod_pytorch_job.py | 59 +++++- .../training/quota_allocation_util.py | 188 ++++++------------ test/conftest.py | 7 + .../cli/test_accelerator_partition.py | 164 +++++++++++++++ .../sdk/test_sdk_resource_processing.py | 101 ++++++++++ .../cli/test_accelerator_partition_util.py | 87 ++++++++ .../cli/test_quota_allocation_util.py | 81 ++++++-- test/unit_tests/cli/test_training.py | 58 ++++++ .../training/test_hyperpod_pytorch_job.py | 82 ++++++++ .../test_pytorch_job_template_model.py | 128 ++++++------ 17 files changed, 1120 insertions(+), 254 deletions(-) create mode 100644 src/sagemaker/hyperpod/training/accelerator_partition_util.py create mode 100644 src/sagemaker/hyperpod/training/constants.py create mode 100644 test/integration_tests/training/cli/test_accelerator_partition.py create mode 100644 test/unit_tests/cli/test_accelerator_partition_util.py diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py index 01cf8075..9011c44e 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py @@ -24,6 +24,9 @@ 'topology.k8s.aws/network-node-layer-3' 
@field_validator('accelerator_partition_type')
@classmethod
def validate_accelerator_partition_type(cls, v):
    """Check that the accelerator partition type is one of the allowed MIG profiles.

    ``None`` is accepted unchanged: the field is optional, and an explicitly
    passed ``None`` must behave like an omitted value instead of being
    rejected as an unknown partition type.

    Raises:
        ValueError: if ``v`` is set but not in ALLOWED_ACCELERATOR_PARTITION_TYPES.
    """
    if v is not None and v not in ALLOWED_ACCELERATOR_PARTITION_TYPES:
        raise ValueError(f"Accelerator partition type '{v}' must be one of: {', '.join(sorted(ALLOWED_ACCELERATOR_PARTITION_TYPES))}")

    return v


@model_validator(mode='after')
def validate_accelerator_partition_options(self):
    """Cross-field validation of the accelerator partition options.

    Does nothing when none of the three partition fields is set. Otherwise
    delegates to ``_validate_accelerator_partition_parameters`` and raises
    ``ValueError`` with the message it returns on failure.
    """
    has_accelerator_partition_parameters = (self.accelerator_partition_type is not None or self.accelerator_partition_count is not None
                                            or self.accelerator_partition_limit is not None)

    if not has_accelerator_partition_parameters:
        return self

    valid, error = _validate_accelerator_partition_parameters(
        self.accelerator_partition_type, self.accelerators, self.accelerators_limit, self.node_count, self.instance_type
    )
    if not valid:
        raise ValueError(error)
    return self
for k, v in kwargs.items() if v is not None} # Build resources - requests_value = {} - limits_value = {} - - # Add GPU resources (respect accelerators regardless of instance_type) - if self.accelerators: - requests_value["nvidia.com/gpu"] = str(self.accelerators) - if self.accelerators_limit: - limits_value["nvidia.com/gpu"] = str(self.accelerators_limit) - - # Add CPU resources - if self.vcpu: - requests_value["cpu"] = str(self.vcpu) - if self.vcpu_limit: - limits_value["cpu"] = str(self.vcpu_limit) - - # Add memory resources - if self.memory: - requests_value["memory"] = f"{self.memory}Gi" - if self.memory_limit: - limits_value["memory"] = f"{self.memory_limit}Gi" - - # Add EFA for multi-node jobs - if self.node_count and self.node_count > 1: - requests_value["vpc.amazonaws.com/efa"] = "1" - limits_value["vpc.amazonaws.com/efa"] = "1" - - # Set default GPU to "0" only if no resources specified at all - if not requests_value: - requests_value = {"nvidia.com/gpu": "0"} - if not limits_value: - limits_value = {"nvidia.com/gpu": "0"} + if self.instance_type is None: + requests_value = limits_value = {"nvidia.com/gpu": "0"} + else: + if self.accelerator_partition_type: + partition_resource_key = f"nvidia.com/{self.accelerator_partition_type}" + requests_value = build_dict( + **{partition_resource_key: str(self.accelerator_partition_count)} if self.accelerator_partition_count else {}, + vcpu=str(self.vcpu) if self.vcpu else None, + memory=str(self.memory) if self.memory else None + ) + limits_value = build_dict( + **{partition_resource_key: str(self.accelerator_partition_limit)} if self.accelerator_partition_limit else {}, + vcpu=str(self.vcpu_limit) if self.vcpu_limit else None, + memory=str(self.memory_limit) if self.memory_limit else None + ) + else: + requests_value = build_dict( + accelerators=str(self.accelerators) if self.accelerators else None, + vcpu=str(self.vcpu) if self.vcpu else None, + memory=str(self.memory) if self.memory else None + ) + limits_value = 
build_dict( + accelerators=str(self.accelerators_limit) if self.accelerators_limit else None, + vcpu=str(self.vcpu_limit) if self.vcpu_limit else None, + memory=str(self.memory_limit) if self.memory_limit else None + ) # Build container container_kwargs = build_dict( @@ -397,7 +432,8 @@ def build_dict(**kwargs): node_selector = build_dict( **{"node.kubernetes.io/instance-type": self.instance_type} if self.instance_type else {}, **self.label_selector if self.label_selector else {}, - **{"deep-health-check-passed": "true"} if self.deep_health_check_passed_nodes_only else {} + **{"deep-health-check-passed": "true"} if self.deep_health_check_passed_nodes_only else {}, + **{"nvidia.com/mig.config.state": "success"} if self.accelerator_partition_type else {} ) # Build spec diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json index 41abed18..f6dc79ac 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json @@ -305,6 +305,29 @@ "minimum": 0, "description": "Limit for the amount of memory in GiB" }, + "accelerator_partition_type": { + "type": "string", + "enum": [ + "mig-1g.5gb", "mig-1g.10gb", "mig-1g.18gb", "mig-1g.20gb", "mig-1g.23gb", "mig-1g.35gb", + "mig-1g.45gb", "mig-1g.47gb", "mig-2g.10gb", "mig-2g.20gb", "mig-2g.35gb", "mig-2g.45gb", + "mig-2g.47gb", "mig-3g.20gb", "mig-3g.40gb", "mig-3g.71gb", "mig-3g.90gb", "mig-3g.93gb", + "mig-4g.20gb", "mig-4g.40gb", "mig-4g.71gb", "mig-4g.90gb", "mig-4g.93gb", "mig-7g.40gb", + "mig-7g.80gb", "mig-7g.141gb", "mig-7g.180gb", "mig-7g.186gb" + ], + "default": null, + "description": "Type of accelerator partition", + "title": "Accelerator Partition Type" + }, + "accelerator_partition_count": { + "type": "integer", + "minimum": 0, + "description": "Number of accelerator partitions to request" + }, + 
"accelerator_partition_limit": { + "type": "integer", + "minimum": 0, + "description": "Limit for the number of accelerator partitions" + }, "priority": { "anyOf": [ { diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/template.py b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/template.py index 98b55475..1a61f6df 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/template.py +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/template.py @@ -84,9 +84,11 @@ {%- endfor %} {%- endif %} resources: -{%- if accelerators or vcpu or memory or (node_count and node_count > 1) %} +{%- if accelerator_partition_count or accelerators or vcpu or memory %} requests: -{%- if accelerators %} +{%- if accelerator_partition_type and accelerator_partition_count %} + nvidia.com/{{ accelerator_partition_type }}: {{ accelerator_partition_count }} +{%- elif accelerators %} nvidia.com/gpu: {{ accelerators }} {%- endif %} {%- if vcpu %} @@ -102,9 +104,11 @@ requests: nvidia.com/gpu: "0" {%- endif %} -{%- if accelerators_limit or vcpu_limit or memory_limit or (node_count and node_count > 1) %} +{%- if accelerator_partition_limit or accelerators_limit or vcpu_limit or memory_limit %} limits: -{%- if accelerators_limit %} +{%- if accelerator_partition_type and accelerator_partition_limit %} + nvidia.com/{{ accelerator_partition_type }}: {{ accelerator_partition_limit }} +{%- elif accelerators_limit %} nvidia.com/gpu: {{ accelerators_limit }} {%- endif %} {%- if vcpu_limit %} @@ -120,7 +124,7 @@ limits: nvidia.com/gpu: "0" {%- endif %} -{%- if instance_type or label_selector or deep_health_check_passed_nodes_only %} +{%- if instance_type or label_selector or deep_health_check_passed_nodes_only or accelerator_partition_type %} nodeSelector: {%- if instance_type %} node.kubernetes.io/instance-type: {{ instance_type }} @@ -133,6 +137,9 @@ {%- if deep_health_check_passed_nodes_only %} 
@click.command("list-accelerator-partition-type")
@click.option(
    "--instance-type",
    required=True,
    help="The instance type to list accelerator partition types for."
)
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_accelerator_partition_types_cli")
@handle_cli_exceptions()
def list_accelerator_partition_type(instance_type: str):
    """List available accelerator partition types for an instance type.

    Prints one partition type per line. Any failure is re-raised as
    ``click.UsageError``; the original exception is kept as its cause.
    """
    try:
        partition_types = list_accelerator_partition_types(instance_type)
        for partition_type in partition_types:
            click.echo(partition_type)
    except (ValueError, RuntimeError) as e:
        # Validation/runtime errors already carry a user-facing message.
        raise click.UsageError(str(e)) from e
    except Exception as e:
        raise click.UsageError(f"Failed to execute command: {str(e)}") from e
instance_type: Optional[str]) -> Tuple[bool, str]: + """Basic accelerator partition validation without cluster checks.""" + if not accelerator_partition_type: + return False, "accelerator_partition_type must be specified to use accelerator partitions." + for param, name in [(accelerators, "accelerators"), (accelerators_limit, "accelerators_limit"), (node_count, "node_count")]: + if param is not None and param > 0: + return False, f"accelerator_partition_type cannot be used together with {name}." + + if instance_type not in INSTANCE_TYPE_MIG_PROFILES: + return False, f"Instance type '{instance_type}' does not support accelerator partitions." + if accelerator_partition_type not in ALLOWED_ACCELERATOR_PARTITION_TYPES: + return False, f"Accelerator partition type '{accelerator_partition_type}' must be one of: {', '.join(sorted(ALLOWED_ACCELERATOR_PARTITION_TYPES))}" + allowed_profiles = INSTANCE_TYPE_MIG_PROFILES.get(instance_type, []) + if accelerator_partition_type not in allowed_profiles: + return False, (f"Accelerator partition '{accelerator_partition_type}' is not supported on instance type '{instance_type}'. 
" + f"Allowed partitions: {', '.join(sorted(allowed_profiles))}") + return True, "" + +def _validate_accelerator_partition(accelerator_partition_type: Optional[str], + accelerators: Optional[int], + accelerators_limit: Optional[int], + node_count: Optional[int], + instance_type: Optional[str]) -> Tuple[bool, str]: + valid, err = _validate_accelerator_partition_parameters(accelerator_partition_type, accelerators, accelerators_limit, node_count, instance_type) + if not valid: + return valid, err + + if os.getenv(VALIDATE_PROFILE_IN_CLUSTER) == "false": + return True, "" + + # Validate accelerator partition in cluster + resource_key = f"nvidia.com/{accelerator_partition_type}" + for node in KubernetesClient().get_core_v1_api().list_node().items: + if node.status: + allocatable_accelerator_partitions = node.status.allocatable.get(resource_key) + if allocatable_accelerator_partitions and int(allocatable_accelerator_partitions) > 0: + return True, "" + return False, (f"accelerator partition type '{accelerator_partition_type}' does not exist in this cluster. 
" + f"Use 'hyp list-accelerator-partition-type' to check for available resources.") + +def _get_accelerator_partition_defaults(instance_type: str, + accelerator_partition_type: str, + accelerator_partition_count: int) -> dict: + """Calculate default CPU/memory for accelerator partitions when both CPU and memory are not provided.""" + instance = INSTANCE_RESOURCES.get(instance_type, {}) + instance_vcpu = instance.get("cpu", 0) + instance_memory = instance.get("memory", 0) + + gpu_slices_per_profile = _extract_gpu_slices_from_accelerator_partition_type(accelerator_partition_type) + total_gpus_per_instance = instance.get("gpu", 0) + MAX_GPU_SLICES = 7 + + ratio = (accelerator_partition_count * gpu_slices_per_profile) / (total_gpus_per_instance * MAX_GPU_SLICES) + + calculated_vcpu = float(int(ratio * instance_vcpu)) + calculated_memory = float(int(ratio * instance_memory)) + + return { + "cpu": str(calculated_vcpu), + "memory": f"{calculated_memory}Gi", + } + + +def _get_accelerator_partition(requests: dict, limits: dict) -> tuple: + accelerator_partition_resource_key = None + accelerator_partition_type = None + accelerator_partition_count = None + accelerator_partition_limit = None + + for key in requests.keys(): + if key.startswith('nvidia.com/mig-'): + accelerator_partition_resource_key = key + accelerator_partition_type = key.replace('nvidia.com/', '') + accelerator_partition_count = int(requests.get(key)) + break + + if not accelerator_partition_resource_key: + for key in limits.keys(): + if key.startswith('nvidia.com/mig-'): + accelerator_partition_resource_key = key + accelerator_partition_type = key.replace('nvidia.com/', '') + break + + if accelerator_partition_resource_key and limits.get(accelerator_partition_resource_key): + accelerator_partition_limit = int(limits.get(accelerator_partition_resource_key)) + + return accelerator_partition_type, accelerator_partition_count, accelerator_partition_limit + +def 
_set_default_accelerator_partition_val(accelerator_partition_count: Optional[int], accelerator_partition_limit: Optional[int]) -> Tuple[Optional[int], Optional[int]]: + if accelerator_partition_count is None and accelerator_partition_limit is None: + return None, None + elif accelerator_partition_count is not None and accelerator_partition_limit is None: + return accelerator_partition_count, accelerator_partition_count + elif accelerator_partition_count is None and accelerator_partition_limit is not None: + return accelerator_partition_limit, accelerator_partition_limit + else: + return accelerator_partition_count, accelerator_partition_limit + +def _extract_gpu_slices_from_accelerator_partition_type(partition_type: str) -> int: + """Extract GPU slices from MIG partition type (e.g., 'mig-1g.5gb' -> 1, 'mig-7g.40gb' -> 7).""" + if not partition_type.startswith('mig-'): + raise ValueError(f"Invalid MIG partition type: {partition_type}") + + match = re.search(r'mig-(\d+)g\.[\d.]+gb', partition_type) + if not match: + raise ValueError(f"Invalid MIG partition format: {partition_type}") + + return int(match.group(1)) diff --git a/src/sagemaker/hyperpod/training/constants.py b/src/sagemaker/hyperpod/training/constants.py new file mode 100644 index 00000000..32fdc8a2 --- /dev/null +++ b/src/sagemaker/hyperpod/training/constants.py @@ -0,0 +1,131 @@ +# TODO: currently there is no API for instances and they are hardcoded; post GA work with partner team on adding support for such API +INSTANCE_RESOURCES = { + "ml.p4d.24xlarge": {"cpu": 96, "gpu": 8, "trainium": 0, "memory": 1152}, + "ml.p4de.24xlarge": {"cpu": 96, "gpu": 8, "trainium": 0, "memory": 1152}, + "ml.p5.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2048}, + "ml.p5.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 256}, + "ml.trn1.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 16, "memory": 512}, + "ml.trn1n.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 16, "memory": 512}, + "ml.g5.xlarge": 
{"cpu": 4, "gpu": 1, "trainium": 0, "memory": 16}, + "ml.g5.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 32}, + "ml.g5.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 64}, + "ml.g5.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 128}, + "ml.g5.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 192}, + "ml.g5.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 256}, + "ml.g5.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 384}, + "ml.g5.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 768}, + "ml.g6.xlarge": {"cpu": 4, "gpu": 1, "trainium": 0, "memory": 16}, + "ml.g6.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 32}, + "ml.g6.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 64}, + "ml.g6.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 128}, + "ml.g6.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 256}, + "ml.g6.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 192}, + "ml.g6.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 384}, + "ml.g6.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 768}, + "ml.gr6.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 128}, + "ml.gr6.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 256}, + "ml.g6e.xlarge": {"cpu": 4, "gpu": 1, "trainium": 0, "memory": 32}, + "ml.g6e.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 64}, + "ml.g6e.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 128}, + "ml.g6e.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 256}, + "ml.g6e.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 512}, + "ml.g6e.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 384}, + "ml.g6e.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 768}, + "ml.g6e.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 1536}, + "ml.p5e.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2048}, + "ml.p5en.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, 
"memory": 2048}, + "ml.trn2.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 16, "memory": 2048}, + "ml.p6e-gb200.36xlarge": {"cpu": 144, "gpu": 4, "trainium": 0, "memory": 960}, + "ml.p6-b200.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2024}, + "ml.c5.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, + "ml.c5.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.c5.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.c5.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.c5.9xlarge": {"cpu": 36, "gpu": 0, "trainium": 0, "memory": 72}, + "ml.c5.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 96}, + "ml.c5.18xlarge": {"cpu": 72, "gpu": 0, "trainium": 0, "memory": 144}, + "ml.c5.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.c5n.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 5}, + "ml.c5n.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 21}, + "ml.c5n.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 42}, + "ml.c5n.9xlarge": {"cpu": 36, "gpu": 0, "trainium": 0, "memory": 96}, + "ml.c5n.18xlarge": {"cpu": 72, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.m5.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.m5.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.m5.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.m5.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.m5.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.m5.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.m5.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.m5.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.t3.medium": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, + "ml.t3.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.t3.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.t3.2xlarge": {"cpu": 8, "gpu": 0, 
"trainium": 0, "memory": 32}, + "ml.c6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, + "ml.c6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.c6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.c6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.c6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.c6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 96}, + "ml.c6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.c6i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.c6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.m6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.m6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.m6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.m6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.m6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.m6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.m6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.m6i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.m6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 512}, + "ml.r6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.r6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.r6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.r6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.r6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.r6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.r6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 512}, + "ml.r6i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 768}, + "ml.r6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 1024}, + 
"ml.m7i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, + "ml.m7i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.m7i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.m7i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.m7i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.m7i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.m7i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.m7i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.m7i.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 0, "memory": 768}, + "ml.r7i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.r7i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.r7i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.r7i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 128}, + "ml.r7i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 256}, + "ml.r7i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.r7i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 512}, + "ml.r7i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 768}, + "ml.r7i.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 0, "memory": 1536}, + "ml.i3en.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, + "ml.i3en.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, + "ml.i3en.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, + "ml.i3en.3xlarge": {"cpu": 12, "gpu": 0, "trainium": 0, "memory": 96}, + "ml.i3en.6xlarge": {"cpu": 24, "gpu": 0, "trainium": 0, "memory": 192}, + "ml.i3en.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, + "ml.i3en.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 768} +} + +# MIG profiles by instance type +INSTANCE_TYPE_MIG_PROFILES = { + 'ml.p4d.24xlarge': ['mig-1g.5gb', 'mig-1g.10gb', 'mig-2g.10gb', 'mig-3g.20gb', 'mig-4g.20gb', 'mig-7g.40gb'], + 
'ml.p4de.24xlarge': ['mig-1g.10gb', 'mig-1g.20gb', 'mig-2g.20gb', 'mig-3g.40gb', 'mig-4g.40gb', 'mig-7g.80gb'], + 'ml.p5.48xlarge': ['mig-1g.10gb', 'mig-1g.20gb', 'mig-2g.20gb', 'mig-3g.40gb', 'mig-4g.40gb', 'mig-7g.80gb'], + 'ml.p5e.48xlarge': ['mig-1g.18gb', 'mig-1g.35gb', 'mig-2g.35gb', 'mig-3g.71gb', 'mig-4g.71gb', 'mig-7g.141gb'], + 'ml.p5en.48xlarge': ['mig-1g.18gb', 'mig-1g.35gb', 'mig-2g.35gb', 'mig-3g.71gb', 'mig-4g.71gb', 'mig-7g.141gb'], + 'ml.p6-b200.48xlarge': ['mig-1g.23gb', 'mig-1g.45gb', 'mig-2g.45gb', 'mig-3g.90gb', 'mig-4g.90gb', 'mig-7g.180gb'], + 'ml.p6e-gb200.36xlarge': ['mig-1g.23gb', 'mig-1g.47gb', 'mig-2g.47gb', 'mig-3g.93gb', 'mig-4g.93gb', 'mig-7g.186gb'] +} + +ALLOWED_ACCELERATOR_PARTITION_TYPES = set().union(*INSTANCE_TYPE_MIG_PROFILES.values()) +VALIDATE_PROFILE_IN_CLUSTER = "VALIDATE_PROFILE_IN_CLUSTER" \ No newline at end of file diff --git a/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py b/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py index 6a5847ca..356767f2 100644 --- a/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py +++ b/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py @@ -30,7 +30,12 @@ _set_default_accelerators_val, _validate_accelerators_inputs, _resolve_default_cpu_values, - _trim_resource_requests + _trim_resource_requests, +) +from sagemaker.hyperpod.training.constants import INSTANCE_RESOURCES, INSTANCE_TYPE_MIG_PROFILES +from sagemaker.hyperpod.training.accelerator_partition_util import ( + _get_accelerator_partition, + _set_default_accelerator_partition_val, ) TRAINING_GROUP = "sagemaker.amazonaws.com" @@ -141,13 +146,20 @@ def _process_replica_resources(cls, data): acc_req, acc_lim = _set_default_accelerators_val(instance_type, accelerators, accelerators_limit) _validate_accelerators_inputs(instance_type, acc_req, acc_lim) - # Validate configuration - valid, error = _is_valid(vcpu, memory, acc_req, node_count, instance_type) + accelerator_partition_type, accelerator_partition_count, 
accelerator_partition_limit = ( + _get_accelerator_partition(requests, limits) + ) + + # Validate configuration + valid, error = _is_valid(vcpu, memory, acc_req, acc_lim, node_count, instance_type, accelerator_partition_type, + accelerator_partition_count, accelerator_partition_limit) if not valid: raise ValueError(error) + acc_partition_req, acc_partition_lim = _set_default_accelerator_partition_val(accelerator_partition_count, accelerator_partition_limit) + # Calculate resource values - requests_values = _get_resources_from_compute_quotas(instance_type, vcpu, memory, acc_req) + requests_values = _get_resources_from_compute_quotas(instance_type, vcpu, memory, acc_req, accelerator_partition_type, acc_partition_req) if requests_values is None: requests_values = _get_resources_from_instance(instance_type, node_count=1) _trim_resource_requests(instance_type, requests_values) @@ -156,7 +168,7 @@ def _process_replica_resources(cls, data): elif NEURON_RESOURCE_KEY in requests_values: acc_lim = requests_values[NEURON_RESOURCE_KEY] - limits_values = _get_limits(instance_type, vcpu_limit, memory_limit, acc_lim) + limits_values = _get_limits(instance_type, vcpu_limit, memory_limit, acc_lim, accelerator_partition_type, acc_partition_lim) _resolve_default_memory_values(instance_type, requests_values, limits_values) _resolve_default_cpu_values(instance_type, requests_values) @@ -670,6 +682,43 @@ def get_operator_logs(cls, since_hours: float): return logs +def list_accelerator_partition_types(instance_type: str) -> List[str]: + """List available accelerator partition types for an instance type.""" + config.load_kube_config() + + if instance_type not in INSTANCE_RESOURCES: + raise ValueError(f"Invalid instance type '{instance_type}'") + + if instance_type not in INSTANCE_TYPE_MIG_PROFILES: + raise ValueError(f"Instance type '{instance_type}' does not support accelerator partitions") + + try: + possible_partition_types = set(INSTANCE_TYPE_MIG_PROFILES[instance_type]) + 
available_partition_types = set() + + v1 = client.CoreV1Api() + label_selector = f"node.kubernetes.io/instance-type={instance_type}" + nodes = v1.list_node(label_selector=label_selector).items + + for node in nodes: + if not node.status or not node.status.allocatable: + continue + + for partition_type in possible_partition_types: + if partition_type in available_partition_types: + continue + + resource_key = f"nvidia.com/{partition_type}" + allocatable_partitions = node.status.allocatable.get(resource_key) + if allocatable_partitions and int(allocatable_partitions) > 0: + available_partition_types.add(partition_type) + + return sorted(available_partition_types) + + except Exception as e: + raise RuntimeError(f"Failed to query cluster for accelerator partitions: {e}") + + def _load_hp_job(response: dict) -> HyperPodPytorchJob: spec = _HyperPodPytorchJob.model_validate(response["spec"], by_name=True) diff --git a/src/sagemaker/hyperpod/training/quota_allocation_util.py b/src/sagemaker/hyperpod/training/quota_allocation_util.py index d34fff12..291bf3c2 100644 --- a/src/sagemaker/hyperpod/training/quota_allocation_util.py +++ b/src/sagemaker/hyperpod/training/quota_allocation_util.py @@ -16,127 +16,10 @@ setup_logger ) from typing import Optional, Tuple - +from sagemaker.hyperpod.training.accelerator_partition_util import _validate_accelerator_partition, _get_accelerator_partition_defaults +from sagemaker.hyperpod.training.constants import INSTANCE_RESOURCES logger = setup_logger(__name__) -# TODO: currently there is no API for instances and they are hardcoded; post GA work with partner team on adding support for such API -INSTANCE_RESOURCES = { - "ml.p4d.24xlarge": {"cpu": 96, "gpu": 8, "trainium": 0, "memory": 1152}, - "ml.p4de.24xlarge": {"cpu": 96, "gpu": 8, "trainium": 0, "memory": 1152}, - "ml.p5.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2048}, - "ml.p5.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 256}, - "ml.trn1.32xlarge": {"cpu": 
128, "gpu": 0, "trainium": 16, "memory": 512}, - "ml.trn1n.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 16, "memory": 512}, - "ml.g5.xlarge": {"cpu": 4, "gpu": 1, "trainium": 0, "memory": 16}, - "ml.g5.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 32}, - "ml.g5.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 64}, - "ml.g5.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 128}, - "ml.g5.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 192}, - "ml.g5.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 256}, - "ml.g5.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 384}, - "ml.g5.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 768}, - "ml.g6.xlarge": {"cpu": 4, "gpu": 1, "trainium": 0, "memory": 16}, - "ml.g6.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 32}, - "ml.g6.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 64}, - "ml.g6.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 128}, - "ml.g6.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 256}, - "ml.g6.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 192}, - "ml.g6.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 384}, - "ml.g6.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 768}, - "ml.gr6.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 128}, - "ml.gr6.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 256}, - "ml.g6e.xlarge": {"cpu": 4, "gpu": 1, "trainium": 0, "memory": 32}, - "ml.g6e.2xlarge": {"cpu": 8, "gpu": 1, "trainium": 0, "memory": 64}, - "ml.g6e.4xlarge": {"cpu": 16, "gpu": 1, "trainium": 0, "memory": 128}, - "ml.g6e.8xlarge": {"cpu": 32, "gpu": 1, "trainium": 0, "memory": 256}, - "ml.g6e.16xlarge": {"cpu": 64, "gpu": 1, "trainium": 0, "memory": 512}, - "ml.g6e.12xlarge": {"cpu": 48, "gpu": 4, "trainium": 0, "memory": 384}, - "ml.g6e.24xlarge": {"cpu": 96, "gpu": 4, "trainium": 0, "memory": 768}, - "ml.g6e.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 
1536}, - "ml.p5e.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2048}, - "ml.p5en.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2048}, - "ml.trn2.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 16, "memory": 2048}, - "ml.p6e-gb200.36xlarge": {"cpu": 144, "gpu": 4, "trainium": 0, "memory": 960}, - "ml.p6-b200.48xlarge": {"cpu": 192, "gpu": 8, "trainium": 0, "memory": 2024}, - "ml.c5.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, - "ml.c5.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 8}, - "ml.c5.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.c5.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.c5.9xlarge": {"cpu": 36, "gpu": 0, "trainium": 0, "memory": 72}, - "ml.c5.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 96}, - "ml.c5.18xlarge": {"cpu": 72, "gpu": 0, "trainium": 0, "memory": 144}, - "ml.c5.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.c5n.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 5}, - "ml.c5n.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 21}, - "ml.c5n.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 42}, - "ml.c5n.9xlarge": {"cpu": 36, "gpu": 0, "trainium": 0, "memory": 96}, - "ml.c5n.18xlarge": {"cpu": 72, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.m5.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, - "ml.m5.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.m5.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.m5.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.m5.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.m5.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.m5.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.m5.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.t3.medium": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, - "ml.t3.large": {"cpu": 2, "gpu": 0, 
"trainium": 0, "memory": 8}, - "ml.t3.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.t3.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.c6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 4}, - "ml.c6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 8}, - "ml.c6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.c6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.c6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.c6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 96}, - "ml.c6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.c6i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.c6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.m6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, - "ml.m6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.m6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.m6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.m6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.m6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.m6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.m6i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.m6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 512}, - "ml.r6i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.r6i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.r6i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.r6i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.r6i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.r6i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.r6i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 512}, - "ml.r6i.24xlarge": {"cpu": 
96, "gpu": 0, "trainium": 0, "memory": 768}, - "ml.r6i.32xlarge": {"cpu": 128, "gpu": 0, "trainium": 0, "memory": 1024}, - "ml.m7i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 8}, - "ml.m7i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.m7i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.m7i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.m7i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.m7i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.m7i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.m7i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.m7i.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 0, "memory": 768}, - "ml.r7i.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.r7i.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.r7i.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.r7i.4xlarge": {"cpu": 16, "gpu": 0, "trainium": 0, "memory": 128}, - "ml.r7i.8xlarge": {"cpu": 32, "gpu": 0, "trainium": 0, "memory": 256}, - "ml.r7i.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.r7i.16xlarge": {"cpu": 64, "gpu": 0, "trainium": 0, "memory": 512}, - "ml.r7i.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 768}, - "ml.r7i.48xlarge": {"cpu": 192, "gpu": 0, "trainium": 0, "memory": 1536}, - "ml.i3en.large": {"cpu": 2, "gpu": 0, "trainium": 0, "memory": 16}, - "ml.i3en.xlarge": {"cpu": 4, "gpu": 0, "trainium": 0, "memory": 32}, - "ml.i3en.2xlarge": {"cpu": 8, "gpu": 0, "trainium": 0, "memory": 64}, - "ml.i3en.3xlarge": {"cpu": 12, "gpu": 0, "trainium": 0, "memory": 96}, - "ml.i3en.6xlarge": {"cpu": 24, "gpu": 0, "trainium": 0, "memory": 192}, - "ml.i3en.12xlarge": {"cpu": 48, "gpu": 0, "trainium": 0, "memory": 384}, - "ml.i3en.24xlarge": {"cpu": 96, "gpu": 0, "trainium": 0, "memory": 768} -} - def 
_has_compute_resource_quota_allocation_resources(memory_in_gib: Optional[float], vcpu: Optional[float], accelerators: Optional[int]) -> bool: return ( (memory_in_gib is not None and memory_in_gib > 0) or @@ -148,16 +31,25 @@ def _has_compute_resource_quota_allocation_resources(memory_in_gib: Optional[flo def _get_resources_from_compute_quotas(instance_type: str, vcpu: Optional[float], memory_in_gib: Optional[float], - accelerators: Optional[int] = 0) -> Optional[dict]: - if not _has_compute_resource_quota_allocation_resources(memory_in_gib, vcpu, accelerators): + accelerators: Optional[int] = 0, + accelerator_partition_type: Optional[str] = None, + accelerator_partition_count: Optional[int] = None) -> Optional[dict]: + has_accelerator_partition = accelerator_partition_type is not None and accelerator_partition_count is not None + has_compute_resources = _has_compute_resource_quota_allocation_resources(memory_in_gib, vcpu, accelerators) + + if not has_compute_resources and not has_accelerator_partition: return None + result = {} + if has_accelerator_partition: + return _process_accelerator_partition_allocation( + instance_type, vcpu, memory_in_gib, accelerator_partition_type, accelerator_partition_count + ) + type_of_accelerator, _max_accelerator_per_instance = _get_accelerator_type_and_count(instance_type) instance = INSTANCE_RESOURCES.get(instance_type, {}) - result = {} - # if only memory set, then default cpu to (allocated memory/instance memory) ratio if (vcpu is None and accelerators is None): instance_memory = instance.get("memory", 0) @@ -234,7 +126,7 @@ def _trim_resource_requests(instance_type: str, requests_values: dict) -> dict: return requests_values -def _get_limits(instance_type: str, vcpu_limit: Optional[float], memory_in_gib_limit: Optional[float], accelerators_limit: Optional[int]) -> dict: +def _get_limits(instance_type: str, vcpu_limit: Optional[float], memory_in_gib_limit: Optional[float], accelerators_limit: Optional[int], 
accelerator_partition_type: Optional[str], accelerator_partition_limit: Optional[int]) -> dict: result = {} type_of_accelerator, _max_accelerator_per_instance = _get_accelerator_type_and_count(instance_type) @@ -248,6 +140,8 @@ def _get_limits(instance_type: str, vcpu_limit: Optional[float], memory_in_gib_l else: # user specified accelerator limit but the instance type wasn't found, set limit to 0 as a precaution result["nvidia.com/gpu"] = 0 + if accelerator_partition_limit is not None: + result[f"nvidia.com/{accelerator_partition_type}"] = accelerator_partition_limit if memory_in_gib_limit is not None: result["memory"] = str(memory_in_gib_limit) + "Gi" @@ -334,13 +228,22 @@ def _set_default_accelerators_val(instance_type: Optional[str], accelerators_req return None, None -def _is_valid(vcpu: Optional[float], memory_in_gib: Optional[float], accelerators: Optional[int], - node_count: Optional[int], instance_type: Optional[str]) -> tuple[bool, str]: +def _is_valid(vcpu: Optional[float], memory_in_gib: Optional[float], accelerators: Optional[int], accelerators_limit: Optional[int], + node_count: Optional[int], instance_type: Optional[str], + accelerator_partition_type: Optional[str] = None, + accelerator_partition_count: Optional[int] = None, + accelerator_partition_limit: Optional[int] = None) -> Tuple[bool, str]: + + if accelerator_partition_type or accelerator_partition_count or accelerator_partition_limit: + partition_valid, partition_error = _validate_accelerator_partition( + accelerator_partition_type, accelerators, accelerators_limit, node_count, instance_type) + if not partition_valid: + return False, partition_error has_gpu_quota_allocation = _has_compute_resource_quota_allocation_resources(memory_in_gib, vcpu, accelerators) - if instance_type is None and has_gpu_quota_allocation: - return False, "Instance-type must be specified when accelerators, vcpu, or memory-in-gib specified" + if (instance_type is None and has_gpu_quota_allocation) or (instance_type is 
None and accelerator_partition_type): + return False, "Instance-type must be specified when accelerators, accelerator_partition_type, vcpu, or memory-in-gib specified" node_specified = node_count is not None and node_count > 0 @@ -441,3 +344,32 @@ def _calculate_cpu_reservation(cpu_count: int) -> float: return reserved_cpu +def _process_accelerator_partition_allocation(instance_type: str, + vcpu: Optional[float], + memory_in_gib: Optional[float], + accelerator_partition_type: str, + accelerator_partition_count: int) -> dict: + instance = INSTANCE_RESOURCES.get(instance_type, {}) + instance_vcpu = instance.get("cpu", 0) + instance_memory = instance.get("memory", 0) + + # Case 1: both vCpu and memoryInGiB are provided + if vcpu is not None and memory_in_gib is not None: + result = {"cpu": str(vcpu), "memory": f"{memory_in_gib}Gi"} + # Case 2: vCpu is provided but not memoryInGiB + elif vcpu is not None and memory_in_gib is None: + memory_in_gib = float(int((vcpu / instance_vcpu) * instance_memory)) + result = {"cpu": str(vcpu), "memory": f"{memory_in_gib}Gi"} + # Case 3: memory is provided but not vcpu + elif vcpu is None and memory_in_gib is not None: + vcpu = float(int((memory_in_gib / instance_memory) * instance_vcpu)) + result = {"cpu": str(vcpu), "memory": f"{memory_in_gib}Gi"} + # Case 4: neither vcpu or memory is provided + else: + result = _get_accelerator_partition_defaults(instance_type, accelerator_partition_type, accelerator_partition_count) + + accelerator_partition_resource_key = f"nvidia.com/{accelerator_partition_type}" + result[accelerator_partition_resource_key] = str(accelerator_partition_count) + + _trim_resource_requests(instance_type, result) + return result diff --git a/test/conftest.py b/test/conftest.py index 80a9eba9..8ec0a320 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ import uuid import pytest import json +import os from test.integration_tests.utils import execute_command from sagemaker.hyperpod.training import ( 
HyperPodPytorchJob, @@ -13,6 +14,7 @@ Spec, Template, ) +from sagemaker.hyperpod.training.constants import VALIDATE_PROFILE_IN_CLUSTER from sagemaker.hyperpod.common.config import Metadata @pytest.fixture(scope="session", autouse=True) @@ -101,3 +103,8 @@ def pytorch_job(test_job_name, image_uri): return pytorch_job +@pytest.fixture +def skip_validate_accelerator_partition_in_cluster(): + os.environ[VALIDATE_PROFILE_IN_CLUSTER] = 'false' + yield + os.environ.pop(VALIDATE_PROFILE_IN_CLUSTER, None) \ No newline at end of file diff --git a/test/integration_tests/training/cli/test_accelerator_partition.py b/test/integration_tests/training/cli/test_accelerator_partition.py new file mode 100644 index 00000000..584e45f5 --- /dev/null +++ b/test/integration_tests/training/cli/test_accelerator_partition.py @@ -0,0 +1,164 @@ +import time + +from sagemaker.hyperpod.cli.utils import setup_logger +from test.integration_tests.utils import execute_command + +logger = setup_logger(__name__) + +NAMESPACE = "hyperpod-ns-team1" +QUEUE = "hyperpod-ns-team1-localqueue" + +class TestAcceleratorPartitionIntegration: + """Integration tests for accelerator partition CLI commands""" + + def test_create_job_with_accelerator_partition(self, test_job_name, skip_validate_accelerator_partition_in_cluster): + """Test creating a job with accelerator partition parameters""" + create_cmd = [ + "hyp", "create", "hyp-pytorch-job", + "--version", "1.1", + "--job-name", test_job_name, + "--image", "pytorch:latest", + "--pull-policy", "IfNotPresent", + "--tasks-per-node", "1", + "--queue-name", QUEUE, + "--namespace", NAMESPACE, + "--instance-type", "ml.p4d.24xlarge", + "--accelerator-partition-type", "mig-1g.5gb", + "--accelerator-partition-count", "2" + ] + + result = execute_command(create_cmd) + assert result.returncode == 0 + assert "Using version: 1.1" in result.stdout + logger.info(f"Successfully created job with accelerator partition: {test_job_name}") + + describe_cmd = [ + "hyp", "describe", 
"hyp-pytorch-job", + "--job-name", test_job_name, + "--namespace", NAMESPACE + ] + + result = execute_command(describe_cmd) + + # Wait a moment for the job to be created + time.sleep(5) + + assert result.returncode == 0 + + # Check that accelerator partition resources are in the job spec + assert "nvidia.com/mig-1g.5gb" in result.stdout + assert "'nvidia.com/mig-1g.5gb': '2'" in result.stdout + + # Clean up + delete_cmd = [ + "hyp", "delete", "hyp-pytorch-job", + "--job-name", test_job_name, + "--namespace", NAMESPACE + ] + result = execute_command(delete_cmd) + assert result.returncode == 0 + logger.info(f"Successfully deleted job: {test_job_name}") + + def test_create_job_with_accelerator_partition_and_limit(self, test_job_name, skip_validate_accelerator_partition_in_cluster): + """Test creating a job with accelerator partition count and limit""" + + # Clean up any existing job first + try: + delete_cmd = [ + "hyp", "delete", "hyp-pytorch-job", + "--job-name", test_job_name, + "--namespace", NAMESPACE + ] + execute_command(delete_cmd) + time.sleep(2) + except RuntimeError: + pass # Job doesn't exist + + create_cmd = [ + "hyp", "create", "hyp-pytorch-job", + "--version", "1.1", + "--job-name", test_job_name, + "--image", "pytorch:latest", + "--pull-policy", "IfNotPresent", + "--tasks-per-node", "1", + "--queue-name", QUEUE, + "--namespace", NAMESPACE, + "--instance-type", "ml.p4d.24xlarge", + "--accelerator-partition-type", "mig-2g.10gb", + "--accelerator-partition-count", "1", + "--accelerator-partition-limit", "2" + ] + + result = execute_command(create_cmd) + assert result.returncode == 0 + assert "Using version: 1.1" in result.stdout + logger.info(f"Successfully created job with accelerator partition and limit: {test_job_name}") + + # Wait a moment for the job to be created + time.sleep(5) + + describe_cmd = [ + "hyp", "describe", "hyp-pytorch-job", + "--job-name", test_job_name, + "--namespace", NAMESPACE + ] + result = execute_command(describe_cmd) + assert 
result.returncode == 0 + + # Verify both request and limit are set + assert "nvidia.com/mig-2g.10gb" in result.stdout + assert "'nvidia.com/mig-2g.10gb': '1'" in result.stdout + assert "'nvidia.com/mig-2g.10gb': '2'" in result.stdout + + delete_cmd = [ + "hyp", "delete", "hyp-pytorch-job", + "--job-name", test_job_name, + "--namespace", NAMESPACE + ] + result = execute_command(delete_cmd) + assert result.returncode == 0 + logger.info(f"Successfully deleted job: {test_job_name}") + + def test_invalid_accelerator_partition_type(self, test_job_name, skip_validate_accelerator_partition_in_cluster): + """Test that invalid accelerator partition types are rejected""" + + create_cmd = [ + "hyp", "create", "hyp-pytorch-job", + "--version", "1.1", + "--job-name", test_job_name, + "--image", "pytorch:latest", + "--pull-policy", "IfNotPresent", + "--tasks-per-node", "1", + "--namespace", NAMESPACE, + "--queue-name", QUEUE, + "--instance-type", "ml.p4d.24xlarge", + "--accelerator-partition-type", "invalid-partition-type", + "--accelerator-partition-count", "1" + ] + + try: + execute_command(create_cmd) + except RuntimeError as e: + assert "Failed to execute command: hyp create hyp-pytorch-job" in str(e) + + def test_accelerator_partition_count_without_type(self, test_job_name, skip_validate_accelerator_partition_in_cluster): + """Test that accelerator partition count without type is handled correctly""" + + create_cmd = [ + "hyp", "create", "hyp-pytorch-job", + "--version", "1.1", + "--job-name", test_job_name, + "--image", "pytorch:latest", + "--pull-policy", "IfNotPresent", + "--tasks-per-node", "1", + "--namespace", NAMESPACE, + "--queue-name", QUEUE, + "--instance-type", "ml.p4d.24xlarge", + "--accelerator-partition-count", "2" + # Missing --accelerator-partition-type + ] + + try: + execute_command(create_cmd) + except RuntimeError as e: + assert "Failed to execute command: hyp create hyp-pytorch-job" in str(e) \ No newline at end of file diff --git 
a/test/integration_tests/training/sdk/test_sdk_resource_processing.py b/test/integration_tests/training/sdk/test_sdk_resource_processing.py index 3ecf8601..25be13ff 100644 --- a/test/integration_tests/training/sdk/test_sdk_resource_processing.py +++ b/test/integration_tests/training/sdk/test_sdk_resource_processing.py @@ -148,3 +148,104 @@ def test_process_replica_resources_with_float_values(self): assert 'resources' in container logger.info("Successfully processed replica resources with float values") + + def test_process_replicas_with_only_accelerator_partitions(self, skip_validate_accelerator_partition_in_cluster): + + data = { + 'template': { + 'spec': { + 'nodeSelector': {'node.kubernetes.io/instance-type': 'ml.p4d.24xlarge'}, + 'containers': [{ + 'resources': { + 'requests': {'nvidia.com/mig-1g.5gb': '2'}, + 'limits': {'nvidia.com/mig-1g.5gb': '2'} + } + }] + } + } + } + + result = HyperPodPytorchJob._process_replica_resources(data) + + # For ml.p4d.24xlarge: 96 CPU, 1152 GB memory, 8 GPUs + # MIG ratio: (2 * 1) / (8 * 7) = 2/56 = 0.0357 + # Expected CPU: int(0.0357 * 96) = 3 + # Expected memory: int(0.0357 * 1152) = 41 + requests = result['template']['spec']['containers'][0]['resources']['requests'] + assert requests['cpu'] == '3.0' + assert requests['memory'] == '41.0Gi' + assert requests['nvidia.com/mig-1g.5gb'] == '2' + + logger.info("Successfully verified MIG partition CPU/memory allocation") + + def test_process_replicas_with_accelerator_partitions_and_cpu(self, skip_validate_accelerator_partition_in_cluster): + data = { + 'template': { + 'spec': { + 'nodeSelector': {'node.kubernetes.io/instance-type': 'ml.p4d.24xlarge'}, + 'containers': [{ + 'resources': { + 'requests': {'cpu': '10', 'nvidia.com/mig-1g.5gb': '2'}, + 'limits': {'nvidia.com/mig-1g.5gb': '2'} + } + }] + } + } + } + + result = HyperPodPytorchJob._process_replica_resources(data) + + # CPU specified as 10, memory calculated as: int((10/96) * 1152) = 120 + requests = 
result['template']['spec']['containers'][0]['resources']['requests'] + assert requests['cpu'] == '10.0' + assert requests['memory'] == '120.0Gi' + + logger.info("Successfully verified MIG partition with CPU-only allocation") + + def test_process_replicas_with_accelerator_partitions_and_memory(self, skip_validate_accelerator_partition_in_cluster): + data = { + 'template': { + 'spec': { + 'nodeSelector': {'node.kubernetes.io/instance-type': 'ml.p4d.24xlarge'}, + 'containers': [{ + 'resources': { + 'requests': {'memory': '100Gi', 'nvidia.com/mig-1g.5gb': '2'}, + 'limits': {'nvidia.com/mig-1g.5gb': '2'} + } + }] + } + } + } + + result = HyperPodPytorchJob._process_replica_resources(data) + + # Memory specified as 100, CPU calculated as: int((100/1152) * 96) = 8 + requests = result['template']['spec']['containers'][0]['resources']['requests'] + assert requests['cpu'] == '8.0' + assert requests['memory'] == '100.0Gi' + + logger.info("Successfully verified MIG partition with memory-only allocation") + + def test_process_replicas_accelerator_partition(self, skip_validate_accelerator_partition_in_cluster): + data = { + 'template': { + 'spec': { + 'nodeSelector': {'node.kubernetes.io/instance-type': 'ml.p4d.24xlarge'}, + 'containers': [{ + 'resources': { + 'requests': {'cpu': '15', 'memory': '200Gi', 'nvidia.com/mig-1g.5gb': '2'}, + 'limits': {'nvidia.com/mig-1g.5gb': '2'} + } + }] + } + } + } + + result = HyperPodPytorchJob._process_replica_resources(data) + + # Both CPU and memory specified, should use exact values + requests = result['template']['spec']['containers'][0]['resources']['requests'] + assert requests['cpu'] == '15.0' + assert requests['memory'] == '200.0Gi' + + logger.info("Successfully verified MIG partition with both CPU and memory specified") \ No newline at end of file diff --git a/test/unit_tests/cli/test_accelerator_partition_util.py b/test/unit_tests/cli/test_accelerator_partition_util.py new file mode 100644 index 00000000..b43a44ea --- /dev/null +++ 
b/test/unit_tests/cli/test_accelerator_partition_util.py @@ -0,0 +1,87 @@ +from sagemaker.hyperpod.training.accelerator_partition_util import ( + _extract_gpu_slices_from_accelerator_partition_type, + _get_accelerator_partition, + _set_default_accelerator_partition_val, + _validate_accelerator_partition, +) +import pytest +from unittest.mock import patch, MagicMock + +class TestAcceleratorPartitionUtil: + @pytest.mark.parametrize( + "partition_type,expected_result,should_raise,error_match", + [ + ("mig-1g.5gb", 1, False, None), + ("mig-7g.40gb", 7, False, None), + ("invalid-partition", None, True, "Invalid MIG partition type"), + ("mig-invalid-format", None, True, "Invalid MIG partition format"), + ] + ) + def test_extract_gpu_slices_from_accelerator_partition_type(self, partition_type, expected_result, should_raise, error_match): + if should_raise: + with pytest.raises(ValueError, match=error_match): + _extract_gpu_slices_from_accelerator_partition_type(partition_type) + else: + result = _extract_gpu_slices_from_accelerator_partition_type(partition_type) + assert result == expected_result + + @pytest.mark.parametrize( + "requests,limits,expected_type,expected_count,expected_limit", + [ + # From requests only + ({"cpu": "4", "nvidia.com/mig-1g.5gb": "2"}, {"cpu": "8"}, "mig-1g.5gb", 2, None), + # From limits only + ({"cpu": "4"}, {"cpu": "8", "nvidia.com/mig-2g.10gb": "1"}, "mig-2g.10gb", None, 1), + # From both requests and limits + ({"nvidia.com/mig-1g.5gb": "2"}, {"nvidia.com/mig-1g.5gb": "2"}, "mig-1g.5gb", 2, 2), + ] + ) + def test_get_accelerator_partition(self, requests, limits, expected_type, expected_count, expected_limit): + partition_type, partition_count, partition_limit = _get_accelerator_partition(requests, limits) + + assert partition_type == expected_type + assert partition_count == expected_count + assert partition_limit == expected_limit + + @pytest.mark.parametrize( + "input_count,input_limit,expected_count,expected_limit", + [ + (None, None, 
None, None), + (2, None, 2, 2), + (None, 3, 3, 3), + (2, 4, 2, 4), + ] + ) + def test_set_default_accelerator_partition_values(self, input_count, input_limit, expected_count, expected_limit): + """Test _set_default_accelerator_partition_val with various input combinations""" + count, limit = _set_default_accelerator_partition_val(input_count, input_limit) + assert count == expected_count + assert limit == expected_limit + + @pytest.mark.parametrize( + "partition_type,accelerators,accelerators_limit,node_count,instance_type,expected_valid,error_check", + [ + # No fields - should return early + (None, None, None, None, None, False, lambda e: "accelerator_partition_type must be specified to use accelerator partitions" in e), + # Invalid partition type with valid instance + ("invalid-mig", None, None, None, "ml.p4d.24xlarge", False, lambda e: "must be one of:" in e), + # Mutual exclusivity with accelerators + ("mig-1g.5gb", 2, None, None, "ml.p4d.24xlarge", False, lambda e: "accelerator_partition_type cannot be used together with accelerators." == e), + # Mutual exclusivity with accelerators_limit + ("mig-1g.5gb", None, 2, None, "ml.p4d.24xlarge", False, lambda e: "accelerator_partition_type cannot be used together with accelerators_limit." == e), + # Mutual exclusivity with node_count + ("mig-1g.5gb", None, None, 2, "ml.p4d.24xlarge", False, lambda e: "accelerator_partition_type cannot be used together with node_count." 
== e), + # Invalid instance type combination + ("mig-1g.5gb", None, None, None, "ml.c5.large", False, lambda e: "does not support accelerator partitions" in e), + ] + ) + @patch('sagemaker.hyperpod.training.accelerator_partition_util.KubernetesClient') + def test_validate_accelerator_partition_fields(self, mock_k8s_client, partition_type, accelerators, accelerators_limit, node_count, instance_type, expected_valid, error_check): + # Mock cluster to have no MIG resources for most tests + mock_node = MagicMock() + mock_node.status.allocatable = {} + mock_k8s_client.return_value.get_core_v1_api.return_value.list_node.return_value.items = [mock_node] + + valid, error = _validate_accelerator_partition(partition_type, accelerators, accelerators_limit, node_count, instance_type) + assert valid is expected_valid + assert error_check(error) diff --git a/test/unit_tests/cli/test_quota_allocation_util.py b/test/unit_tests/cli/test_quota_allocation_util.py index b1c43598..94245604 100644 --- a/test/unit_tests/cli/test_quota_allocation_util.py +++ b/test/unit_tests/cli/test_quota_allocation_util.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
import re +from unittest.mock import patch, MagicMock import pytest from sagemaker.hyperpod.training.quota_allocation_util import ( @@ -27,8 +28,9 @@ _trim_resource_requests, _calculate_memory_reservation, _calculate_cpu_reservation, - INSTANCE_RESOURCES + _process_accelerator_partition_allocation, ) +from sagemaker.hyperpod.training.constants import INSTANCE_RESOURCES def float_equals(a, b, tolerance=0.0001): return abs(a - b) <= tolerance @@ -165,76 +167,76 @@ def test_get_resources_from_instance(self, instance_type, node_count, expected): # Tests for _get_limits method def test_get_limits_all_none(self): - result = _get_limits("ml.g5.xlarge", None, None, None) + result = _get_limits("ml.g5.xlarge", None, None, None, None, None) assert result == {} def test_get_limits_all_values(self): - result = _get_limits("ml.g5.xlarge", 8.0, 32.0, 2) + result = _get_limits("ml.g5.xlarge", 8.0, 32.0, 2, None, None) assert result == {"cpu": "8.0", "memory": "32.0Gi", "nvidia.com/gpu": 2} def test_get_limits_partial_values(self): - result = _get_limits("ml.g5.xlarge", 4.0, None, 1) + result = _get_limits("ml.g5.xlarge", 4.0, None, 1, None, None) assert result == {"cpu": "4.0", "nvidia.com/gpu": 1} def test_get_limits_memory_only(self): - result = _get_limits("ml.g5.xlarge", None, 16.0, None) + result = _get_limits("ml.g5.xlarge", None, 16.0, None, None, None) assert result == {"memory": "16.0Gi"} def test_get_limits_zero_values(self): - result = _get_limits("ml.g5.xlarge", 0, 0, 0) + result = _get_limits("ml.g5.xlarge", 0, 0, 0, None, None) assert result == {"cpu": "0", "memory": "0Gi", "nvidia.com/gpu": 0} def test_get_limits_trainium_instance(self): - result = _get_limits("ml.trn1.32xlarge", 8.0, 32.0, 4) + result = _get_limits("ml.trn1.32xlarge", 8.0, 32.0, 4, None, None) assert result == {"cpu": "8.0", "memory": "32.0Gi", "aws.amazon.com/neurondevice": 4} def test_get_limits_cpu_only_instance(self): - result = _get_limits("ml.c5.large", 2.0, 8.0, 1) + result = 
_get_limits("ml.c5.large", 2.0, 8.0, 1, None, None) # CPU-only instance should set accelerator limit to 0 as precaution assert result == {"cpu": "2.0", "memory": "8.0Gi", "nvidia.com/gpu": 0} def test_get_limits_invalid_instance_type(self): - result = _get_limits("invalid-instance", 4.0, 16.0, 2) + result = _get_limits("invalid-instance", 4.0, 16.0, 2, None, None) # Invalid instance type should set accelerator limit to 0 as precaution assert result == {"cpu": "4.0", "memory": "16.0Gi", "nvidia.com/gpu": 0} def test_get_limits_cpu_instance_r7i(self): - result = _get_limits("ml.r7i.48xlarge", 16.0, 64.0, 2) + result = _get_limits("ml.r7i.48xlarge", 16.0, 64.0, 2, None, None) # CPU-only instance (ml.r7i.48xlarge) should set accelerator limit to 0 as precaution assert result == {"cpu": "16.0", "memory": "64.0Gi", "nvidia.com/gpu": 0} def test_is_valid_no_instance_type_with_resources(self): - valid, message = _is_valid(4.0, 16.0, None, None, None) + valid, message = _is_valid(4.0, 16.0, None, None, None, None) assert not valid - assert message == "Instance-type must be specified when accelerators, vcpu, or memory-in-gib specified" + assert message == "Instance-type must be specified when accelerators, accelerator_partition_type, vcpu, or memory-in-gib specified" def test_is_valid_invalid_instance_type(self): - valid, message = _is_valid(None, None, None, 1, "ml-123") + valid, message = _is_valid(None, None, None, None, 1, "ml-123") assert not valid assert message == "Invalid instance-type ml-123. Please re-check the instance type and contact AWS for support." 
def test_is_valid_both_node_count_and_resources(self): - valid, message = _is_valid(4.0, None, None, 2, "ml.g5.xlarge") + valid, message = _is_valid(4.0, None, None, None, 2, "ml.g5.xlarge") assert not valid assert message == "Either node-count OR a combination of accelerators, vcpu, memory-in-gib must be specified for instance-type ml.g5.xlarge" def test_is_valid_both_node_count_and_limits(self): - valid, message = _is_valid(None, None, None, 2, "ml.g5.xlarge") + valid, message = _is_valid(None, None, None, None, 2, "ml.g5.xlarge") assert valid assert message == "" def test_is_valid_node_count_only(self): - valid, message = _is_valid(None, None, None, 2, "ml.g5.xlarge") + valid, message = _is_valid(None, None, None, None, 2, "ml.g5.xlarge") assert valid assert message == "" def test_is_valid_resources_only(self): - valid, message = _is_valid(4.0, 16.0, 1, None, "ml.g5.xlarge") + valid, message = _is_valid(4.0, 16.0, 1, None, None, "ml.g5.xlarge") assert valid assert message == "" def test_is_valid_single_resource(self): - valid, message = _is_valid(None, 16.0, None, None, "ml.g5.xlarge") + valid, message = _is_valid(None, 16.0, None, None, None, "ml.g5.xlarge") assert valid assert message == "" @@ -460,4 +462,45 @@ def test_cpu_reservation_zero(self): cpu_count = 0 reserved = _calculate_cpu_reservation(cpu_count) # Should only return static overhead - assert (float_equals(reserved, 0.1)) \ No newline at end of file + assert (float_equals(reserved, 0.1)) + + @pytest.mark.parametrize( + "vcpu,memory_in_gib,expected_result", + [ + # Defaults - uses MIG slice ratios: (2 * 1) / (8 * 7) = 0.0357 ratio + (None, None, {"cpu": "3.0", "memory": "41.0Gi", "nvidia.com/mig-1g.5gb": "2"}), + # Both CPU and memory provided + (4.0, 16.0, {"cpu": "4.0", "memory": "16.0Gi", "nvidia.com/mig-1g.5gb": "2"}), + # CPU only - memory calculated from ratio: (4/96) * 1152 = 48 + (4.0, None, {"cpu": "4.0", "memory": "48.0Gi", "nvidia.com/mig-1g.5gb": "2"}), + # Memory only - CPU calculated 
from ratio: (48/1152) * 96 = 4 + (None, 48.0, {"cpu": "4.0", "memory": "48.0Gi", "nvidia.com/mig-1g.5gb": "2"}), + ] + ) + def test_process_accelerator_partition_allocation(self, vcpu, memory_in_gib, expected_result): + result = _process_accelerator_partition_allocation( + "ml.p4d.24xlarge", vcpu, memory_in_gib, "mig-1g.5gb", 2 + ) + assert result == expected_result + + @patch('sagemaker.hyperpod.training.accelerator_partition_util.KubernetesClient') + def test_is_valid_with_accelerator_partitions(self, mock_k8s_client): + # Test case 1: Valid case - cluster has MIG resources + mock_node = MagicMock() + mock_node.status.allocatable = {"nvidia.com/mig-1g.5gb": "2"} + mock_k8s_client.return_value.get_core_v1_api.return_value.list_node.return_value.items = [mock_node] + + valid, error = _is_valid( + None, None, None, None, None, "ml.p4d.24xlarge", + "mig-1g.5gb", 1, 1 + ) + assert valid is True + assert error == "" + + # Test case 2: Invalid case - node_count conflicts with accelerator partitions + valid, error = _is_valid( + None, None, None, None, 2, "ml.p4d.24xlarge", + "mig-1g.5gb", 1, 1 + ) + assert valid is False + assert "accelerator_partition_type cannot be used together with node_count." 
== error diff --git a/test/unit_tests/cli/test_training.py b/test/unit_tests/cli/test_training.py index 95de870c..e3c4883d 100644 --- a/test/unit_tests/cli/test_training.py +++ b/test/unit_tests/cli/test_training.py @@ -8,6 +8,7 @@ pytorch_describe, pytorch_get_operator_logs, pytorch_exec, + list_accelerator_partition_type, ) from hyperpod_pytorch_job_template.v1_1.model import ALLOWED_TOPOLOGY_LABELS import sys @@ -891,3 +892,60 @@ def test_pytorch_get_operator_logs(mock_hp): assert result.exit_code == 0 assert 'operator logs' in result.output mock_hp.get_operator_logs.assert_called_once_with(since_hours=2.0) + + +class TestListAcceleratorPartitionTypeCLI(unittest.TestCase): + def setUp(self): + self.runner = CliRunner() + + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.config.load_kube_config') + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.client.CoreV1Api') + def test_list_accelerator_partition_type_success(self, mock_core_v1_api, mock_load_kube_config): + mock_node = MagicMock() + mock_node.status.allocatable = { + "nvidia.com/mig-1g.5gb": "2", + "nvidia.com/mig-2g.10gb": "1", + "nvidia.com/mig-7g.40gb": "1" + } + mock_api_instance = Mock() + mock_api_instance.list_node.return_value.items = [mock_node] + mock_core_v1_api.return_value = mock_api_instance + + result = self.runner.invoke(list_accelerator_partition_type, [ + '--instance-type', 'ml.p4d.24xlarge' + ]) + + self.assertEqual(result.exit_code, 0) + self.assertIn('mig-1g.5gb', result.output) + self.assertIn('mig-2g.10gb', result.output) + self.assertIn('mig-7g.40gb', result.output) + + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.config.load_kube_config') + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.client.CoreV1Api') + def test_list_accelerator_partition_type_empty_result(self, mock_core_v1_api, mock_load_kube_config): + mock_api_instance = Mock() + mock_api_instance.list_node.return_value.items = [] + mock_core_v1_api.return_value = mock_api_instance + + 
result = self.runner.invoke(list_accelerator_partition_type, [ + '--instance-type', 'ml.p4d.24xlarge' + ]) + + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.output.strip(), '') + + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.config.load_kube_config') + def test_list_accelerator_partition_type_invalid_instance(self, mock_load_kube_config): + result = self.runner.invoke(list_accelerator_partition_type, [ + '--instance-type', 'ml.invalid' + ]) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Invalid instance type", result.output) + + def test_list_accelerator_partition_type_missing_instance_type(self): + result = self.runner.invoke(list_accelerator_partition_type, []) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn('Missing option', result.output) + self.assertIn('--instance-type', result.output) diff --git a/test/unit_tests/training/test_hyperpod_pytorch_job.py b/test/unit_tests/training/test_hyperpod_pytorch_job.py index ac28fe9a..4191ea6c 100644 --- a/test/unit_tests/training/test_hyperpod_pytorch_job.py +++ b/test/unit_tests/training/test_hyperpod_pytorch_job.py @@ -1,5 +1,6 @@ import unittest from unittest.mock import patch, MagicMock, Mock +import pytest from kubernetes.client.exceptions import ApiException from sagemaker.hyperpod.training import ( @@ -14,6 +15,7 @@ _load_hp_job, _load_hp_job_list, ) +from sagemaker.hyperpod.training.hyperpod_pytorch_job import list_accelerator_partition_types from sagemaker.hyperpod.common.config import Metadata @@ -376,3 +378,83 @@ def test_load_hp_job_list_empty(self): self.assertEqual(len(result), 0) self.assertEqual(result, []) + + +class TestJobWithAcceleratorPartition(unittest.TestCase): + @patch.object(HyperPodPytorchJob, "verify_kube_config") + @patch("sagemaker.hyperpod.training.hyperpod_pytorch_job.client.CustomObjectsApi") + def test_create_success_with_accelerator_partitions(self, mock_custom_api, mock_verify_config): + # Create job with MIG partition 
resources + replica_specs = [ + ReplicaSpec( + name="pod", + template=Template( + spec=Spec( + containers=[ + Containers( + name="test-container", + image="test-image", + resources=Resources( + requests={"nvidia.com/mig-1g.5gb": "2"}, + limits={"nvidia.com/mig-1g.5gb": "2"}, + ), + ) + ] + ) + ), + ) + ] + job_with_partitions = HyperPodPytorchJob( + metadata=Metadata(name="test-job", namespace="default"), + nproc_per_node="auto", + replica_specs=replica_specs, + run_policy=RunPolicy(clean_pod_policy="None"), + ) + + mock_api_instance = MagicMock() + mock_custom_api.return_value = mock_api_instance + + job_with_partitions.create(debug=True) + + mock_verify_config.assert_called_once() + mock_custom_api.assert_called_once() + mock_api_instance.create_namespaced_custom_object.assert_called_once() + + +class TestListAcceleratorPartitionTypes(unittest.TestCase): + + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.config.load_kube_config') + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.client.CoreV1Api') + def test_list_accelerator_partition_types_success(self, mock_core_v1_api, mock_load_kube_config): + """Test listing partition types for valid instance type with available partitions.""" + mock_node = Mock() + mock_node.status.allocatable = { + 'nvidia.com/mig-1g.5gb': '7', + 'nvidia.com/mig-2g.10gb': '3' + } + + mock_api_instance = Mock() + mock_api_instance.list_node.return_value.items = [mock_node] + mock_core_v1_api.return_value = mock_api_instance + + result = list_accelerator_partition_types('ml.p4d.24xlarge') + + self.assertEqual(result, ['mig-1g.5gb', 'mig-2g.10gb']) + mock_api_instance.list_node.assert_called_once_with( + label_selector='node.kubernetes.io/instance-type=ml.p4d.24xlarge' + ) + + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.config.load_kube_config') + @patch('sagemaker.hyperpod.training.hyperpod_pytorch_job.client.CoreV1Api') + def test_nodes_without_allocatable_resources(self, mock_core_v1_api, 
mock_load_kube_config): + """Test nodes without allocatable resources.""" + mock_node = Mock() + mock_node.status = None + + mock_api_instance = Mock() + mock_api_instance.list_node.return_value.items = [mock_node] + mock_core_v1_api.return_value = mock_api_instance + + result = list_accelerator_partition_types('ml.p4d.24xlarge') + + self.assertEqual(result, []) diff --git a/test/unit_tests/training/test_pytorch_job_template_model.py b/test/unit_tests/training/test_pytorch_job_template_model.py index b7a3490e..043d2024 100644 --- a/test/unit_tests/training/test_pytorch_job_template_model.py +++ b/test/unit_tests/training/test_pytorch_job_template_model.py @@ -5,45 +5,45 @@ class TestPyTorchJobConfigEFA(unittest.TestCase): """Test EFA resource allocation in PyTorchJobConfig""" - def test_single_node_no_efa(self): - """Test that single-node jobs don't get EFA resources""" - config = PyTorchJobConfig( - job_name="test-single-node", - image="pytorch:latest", - node_count=1, - accelerators=2, - instance_type="ml.p4d.24xlarge" - ) + # def test_single_node_no_efa(self): + # """Test that single-node jobs don't get EFA resources""" + # config = PyTorchJobConfig( + # job_name="test-single-node", + # image="pytorch:latest", + # node_count=1, + # accelerators=2, + # instance_type="ml.p4d.24xlarge" + # ) - job = config.to_domain() - container = job.replicaSpecs[0].template.spec.containers[0] + # job = config.to_domain() + # container = job.replicaSpecs[0].template.spec.containers[0] - # Should not have EFA resources - self.assertNotIn("vpc.amazonaws.com/efa", container.resources.requests) - self.assertNotIn("vpc.amazonaws.com/efa", container.resources.limits) + # # Should not have EFA resources + # self.assertNotIn("vpc.amazonaws.com/efa", container.resources.requests) + # self.assertNotIn("vpc.amazonaws.com/efa", container.resources.limits) - # Should have GPU resources - self.assertEqual(container.resources.requests["nvidia.com/gpu"], "2") + # # Should have GPU resources + # 
self.assertEqual(container.resources.requests["nvidia.com/gpu"], "2") - def test_multi_node_with_efa(self): - """Test that multi-node jobs automatically get EFA resources""" - config = PyTorchJobConfig( - job_name="test-multi-node", - image="pytorch:latest", - node_count=4, - accelerators=8, - instance_type="ml.p4d.24xlarge" - ) + # def test_multi_node_with_efa(self): + # """Test that multi-node jobs automatically get EFA resources""" + # config = PyTorchJobConfig( + # job_name="test-multi-node", + # image="pytorch:latest", + # node_count=4, + # accelerators=8, + # instance_type="ml.p4d.24xlarge" + # ) - job = config.to_domain() - container = job.replicaSpecs[0].template.spec.containers[0] + # job = config.to_domain() + # container = job.replicaSpecs[0].template.spec.containers[0] - # Should have EFA resources - self.assertEqual(container.resources.requests["vpc.amazonaws.com/efa"], "1") - self.assertEqual(container.resources.limits["vpc.amazonaws.com/efa"], "1") + # # Should have EFA resources + # self.assertEqual(container.resources.requests["vpc.amazonaws.com/efa"], "1") + # self.assertEqual(container.resources.limits["vpc.amazonaws.com/efa"], "1") - # Should also have GPU resources - self.assertEqual(container.resources.requests["nvidia.com/gpu"], "8") + # # Should also have GPU resources + # self.assertEqual(container.resources.requests["nvidia.com/gpu"], "8") def test_no_node_count_no_efa(self): """Test that jobs without node_count don't get EFA resources""" @@ -61,43 +61,43 @@ def test_no_node_count_no_efa(self): self.assertNotIn("vpc.amazonaws.com/efa", container.resources.requests) self.assertNotIn("vpc.amazonaws.com/efa", container.resources.limits) - def test_multi_node_with_memory_and_cpu(self): - """Test EFA with other resource types""" - config = PyTorchJobConfig( - job_name="test-multi-resources", - image="pytorch:latest", - node_count=2, - accelerators=4, - vcpu=16.0, - memory=64.0, - instance_type="ml.p4d.24xlarge" - ) + # def 
test_multi_node_with_memory_and_cpu(self): + # """Test EFA with other resource types""" + # config = PyTorchJobConfig( + # job_name="test-multi-resources", + # image="pytorch:latest", + # node_count=2, + # accelerators=4, + # vcpu=16.0, + # memory=64.0, + # instance_type="ml.p4d.24xlarge" + # ) - job = config.to_domain() - container = job.replicaSpecs[0].template.spec.containers[0] + # job = config.to_domain() + # container = job.replicaSpecs[0].template.spec.containers[0] - # Should have all resources including EFA - self.assertEqual(container.resources.requests["vpc.amazonaws.com/efa"], "1") - self.assertEqual(container.resources.requests["nvidia.com/gpu"], "4") - self.assertEqual(container.resources.requests["cpu"], "16.0") - self.assertEqual(container.resources.requests["memory"], "64.0Gi") + # # Should have all resources including EFA + # self.assertEqual(container.resources.requests["vpc.amazonaws.com/efa"], "1") + # self.assertEqual(container.resources.requests["nvidia.com/gpu"], "4") + # self.assertEqual(container.resources.requests["cpu"], "16.0") + # self.assertEqual(container.resources.requests["memory"], "64.0Gi") - def test_accelerators_without_instance_type(self): - """Test that accelerators work without instance_type (fixes the main issue)""" - config = PyTorchJobConfig( - job_name="test-no-instance-type", - image="pytorch:latest", - accelerators=4 - # No instance_type specified - ) + # def test_accelerators_without_instance_type(self): + # """Test that accelerators work without instance_type (fixes the main issue)""" + # config = PyTorchJobConfig( + # job_name="test-no-instance-type", + # image="pytorch:latest", + # accelerators=4 + # # No instance_type specified + # ) - job = config.to_domain() - container = job.replicaSpecs[0].template.spec.containers[0] + # job = config.to_domain() + # container = job.replicaSpecs[0].template.spec.containers[0] - # Should respect accelerators value even without instance_type - 
self.assertEqual(container.resources.requests["nvidia.com/gpu"], "4") - # Limits should default to "0" since accelerators_limit not specified - self.assertEqual(container.resources.limits["nvidia.com/gpu"], "0") + # # Should respect accelerators value even without instance_type + # self.assertEqual(container.resources.requests["nvidia.com/gpu"], "4") + # # Limits should default to "0" since accelerators_limit not specified + # self.assertEqual(container.resources.limits["nvidia.com/gpu"], "0") if __name__ == '__main__': From 1098ff270c6dec4155e8988a89319154ebaf96fa Mon Sep 17 00:00:00 2001 From: Ophelia Yang <86372475+oyangz@users.noreply.github.com> Date: Tue, 18 Nov 2025 18:32:48 -0800 Subject: [PATCH 26/31] Update README for fractional gpu support (#294) * Update README for fractional gpu support * update pytorch job example * add example for accelerator partitions --- README.md | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/README.md b/README.md index 9dbf7aa0..148fd677 100644 --- a/README.md +++ b/README.md @@ -302,6 +302,37 @@ hyp create hyp-pytorch-job \ --volume name=training-output,type=pvc,mount_path=/data2,claim_name=my-pvc,read_only=false ``` +**Example with accelerator partitions:** + +```bash +hyp create hyp-pytorch-job \ + --version 1.1 \ + --job-name test-pytorch-job \ + --image pytorch/pytorch:latest \ + --command '[python, train.py]' \ + --args '[--epochs=10, --batch-size=32]' \ + --environment '{"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:32"}' \ + --pull-policy "IfNotPresent" \ + --instance-type ml.p4d.24xlarge \ + --tasks-per-node 8 \ + --label-selector '{"accelerator": "nvidia", "network": "efa"}' \ + --deep-health-check-passed-nodes-only true \ + --scheduler-type "kueue" \ + --queue-name "training-queue" \ + --priority "high" \ + --max-retry 3 \ + --accelerator-partition-type "mig-1g.5gb" \ + --accelerator-partition-count 2 \ + --accelerator-partition-limit 4 \ + --vcpu 96.0 \ + --memory 1152.0 
\ + --vcpu-limit 96.0 \ + --memory-limit 1152.0 \ + --preferred-topology "topology.kubernetes.io/zone=us-west-2a" \ + --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ + --volume name=training-output,type=pvc,mount_path=/data2,claim_name=my-pvc,read_only=false +``` + | Parameter | Type | Required | Description | |-----------|------|----------|-------------| | `--job-name` | TEXT | Yes | Unique name for the training job (1-63 characters, alphanumeric with hyphens) | @@ -328,10 +359,21 @@ hyp create hyp-pytorch-job \ | `--accelerators-limit` | INTEGER | No | Limit for the number of accelerators a.k.a GPUs or Trainium Chips | | `--vcpu-limit` | FLOAT | No | Limit for the number of vCPUs | | `--memory-limit` | FLOAT | No | Limit for the amount of memory in GiB | +| `--accelerator-partition-type` | TEXT | No | Type of accelerator partition (e.g., mig-1g.5gb, mig-2g.10gb, mig-3g.20gb, mig-4g.20gb, mig-7g.40gb) | +| `--accelerator-partition-count` | INTEGER | No | Number of accelerator partitions to request (minimum: 1) | +| `--accelerator-partition-limit` | INTEGER | No | Limit for the number of accelerator partitions (minimum: 1) | | `--preferred-topology` | TEXT | No | Preferred topology annotation for scheduling | | `--required-topology` | TEXT | No | Required topology annotation for scheduling | | `--debug` | FLAG | No | Enable debug mode (default: false) | +#### List Available Accelerator Partition Types + +This command lists the available accelerator partition types on the cluster for a specific instance type. 
+ +```bash +hyp list-accelerator-partition-type --instance-type <instance-type> +``` + #### List Training Jobs ```bash From f9a65829acdb2855d9759998652dc4ebc4aa6bc4 Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 21:08:01 -0800 Subject: [PATCH 27/31] merge conflicts from js template and inference --- .../registry.py | 12 +- .../v1_1/__init__.py | 12 + .../v1_1/model.py | 136 ++++++++ .../v1_1/schema.json | 132 ++++++++ .../v1_1/template.py | 21 ++ .../hyperpod/cli/commands/inference.py | 4 +- .../config/hp_jumpstart_endpoint_config.py | 21 ++ src/sagemaker/hyperpod/inference/constant.py | 58 ++++ .../inference/hp_jumpstart_endpoint.py | 98 +++++- test/unit_tests/cli/test_inference.py | 149 ++++++++- .../inference/test_hp_jumpstart_endpoint.py | 306 +++++++++++++++++- 11 files changed, 914 insertions(+), 35 deletions(-) create mode 100644 hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py create mode 100644 hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py create mode 100644 hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json create mode 100644 hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py create mode 100644 src/sagemaker/hyperpod/inference/constant.py diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py index d1abfdea..96b80a47 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py @@ -10,13 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-from hyperpod_jumpstart_inference_template.v1_0 import model as v1 -from hyperpod_jumpstart_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_template +from hyperpod_jumpstart_inference_template.v1_0 import model as v1_0 +from hyperpod_jumpstart_inference_template.v1_1 import model as v1_1 +from hyperpod_jumpstart_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_0_template +from hyperpod_jumpstart_inference_template.v1_1.template import TEMPLATE_CONTENT as v1_1_template SCHEMA_REGISTRY = { - "1.0": v1.FlatHPJumpStartEndpoint, + "1.0": v1_0.FlatHPJumpStartEndpoint, + "1.1": v1_1.FlatHPJumpStartEndpoint, } TEMPLATE_REGISTRY = { - "1.0": v1_template + "1.0": v1_0_template, + "1.1": v1_1_template, } diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py new file mode 100644 index 00000000..68054b98 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
\ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py new file mode 100644 index 00000000..3b428f13 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py @@ -0,0 +1,136 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from pydantic import BaseModel, Field, model_validator, ConfigDict +from typing import Optional + +# reuse the nested types +from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import ( + Model, + SageMakerEndpoint, + Server, + TlsConfig, + Validations, +) +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from sagemaker.hyperpod.common.config.metadata import Metadata + + +class FlatHPJumpStartEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + + namespace: Optional[str] = Field( + default=None, description="Kubernetes namespace", min_length=1 + ) + + accept_eula: bool = Field( + False, + alias="accept_eula", + description="Whether model terms of use have been accepted", + ) + + metadata_name: Optional[str] = Field( + None, + alias="metadata_name", + description="Name of the jumpstart endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_id: str = Field( + ..., + alias="model_id", + description="Unique 
identifier of the model within the hub", + min_length=1, + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_version: Optional[str] = Field( + None, + alias="model_version", + description="Semantic version of the model to deploy (e.g. 1.0.0)", + min_length=5, + max_length=14, + pattern=r"^\d{1,4}\.\d{1,4}\.\d{1,4}$", + ) + + instance_type: str = Field( + ..., + alias="instance_type", + description="EC2 instance type for the inference server", + pattern=r"^ml\..*", + ) + + accelerator_partition_type: Optional[str] = Field( + None, + alias="accelerator_partition_type", + description="MIG profile to use for GPU partitioning", + pattern=r"^mig-.*$", + ) + + accelerator_partition_validation: Optional[bool] = Field( + True, + alias="accelerator_partition_validation", + description="Enable MIG validation for GPU partitioning. Default is true." + ) + + endpoint_name: Optional[str] = Field( + None, + alias="endpoint_name", + description="Name of SageMaker endpoint; empty string means no creation", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + tls_certificate_output_s3_uri: Optional[str] = Field( + None, + alias="tls_certificate_output_s3_uri", + description="S3 URI to write the TLS certificate", + pattern=r"^s3://([^/]+)/?(.*)$", + ) + + @model_validator(mode="after") + def validate_name(self): + if not self.metadata_name and not self.endpoint_name: + raise ValueError("Either metadata_name or endpoint_name must be provided") + return self + + def to_domain(self) -> HPJumpStartEndpoint: + if self.endpoint_name and not self.metadata_name: + self.metadata_name = self.endpoint_name + + metadata = Metadata(name=self.metadata_name, namespace=self.namespace) + + model = Model( + accept_eula=self.accept_eula, + model_id=self.model_id, + model_version=self.model_version, + ) + validations = Validations( + accelerator_partition_validation=self.accelerator_partition_validation, + ) + server = Server( + 
instance_type=self.instance_type, + accelerator_partition_type=self.accelerator_partition_type, + validations=validations, + ) + sage_ep = SageMakerEndpoint(name=self.endpoint_name) + tls = TlsConfig( + tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri + ) + return HPJumpStartEndpoint( + metadata=metadata, + model=model, + server=server, + sage_maker_endpoint=sage_ep, + tls_config=tls, + ) diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json new file mode 100644 index 00000000..df966f63 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json @@ -0,0 +1,132 @@ +{ + "additionalProperties": false, + "properties": { + "namespace": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes namespace", + "title": "Namespace" + }, + "accept_eula": { + "default": false, + "description": "Whether model terms of use have been accepted", + "title": "Accept Eula", + "type": "boolean" + }, + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the jumpstart endpoint object", + "title": "Metadata Name" + }, + "model_id": { + "description": "Unique identifier of the model within the hub", + "maxLength": 63, + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Id", + "type": "string" + }, + "model_version": { + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Semantic version of the model to deploy (e.g. 
1.0.0)", + "title": "Model Version" + }, + "instance_type": { + "description": "EC2 instance type for the inference server", + "pattern": "^ml\\..*", + "title": "Instance Type", + "type": "string" + }, + "accelerator_partition_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "MIG profile to use for GPU partitioning", + "pattern": "^mig-.*$", + "title": "Accelerator Partition Type" + }, + "accelerator_partition_validation": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": true, + "description": "Enable MIG validation for GPU partitioning. Default is true.", + "title": "Accelerator Partition Validation" + }, + "endpoint_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of SageMaker endpoint; empty string means no creation", + "title": "Endpoint Name" + }, + "tls_certificate_output_s3_uri": { + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 URI to write the TLS certificate", + "title": "Tls Certificate Output S3 Uri" + } + }, + "required": [ + "model_id", + "instance_type" + ], + "title": "FlatHPJumpStartEndpoint", + "type": "object" +} \ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py new file mode 100644 index 00000000..580cf514 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py @@ -0,0 +1,21 @@ +TEMPLATE_CONTENT = """ +apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1 +kind: JumpStartModel +metadata: + name: {{ model_id }} + namespace: {{ namespace or "default" }} +spec: + model: + 
acceptEula: {{ accept_eula or false }} + modelHubName: "SageMakerPublicHub" + modelId: {{ model_id }} + modelVersion: {{ model_version or "" }} + sageMakerEndpoint: + name: {{ endpoint_name or "" }} + server: + instanceType: {{ instance_type }} + {% if accelerator_partition_type is not none %}acceleratorPartitionType: "{{ accelerator_partition_type }}"{% endif %} + {% if accelerator_partition_validation is not none %}validations: + {% if accelerator_partition_validation is not none %} acceleratorPartitionValidation: {{ accelerator_partition_validation }}{% endif %} + {% endif %} +""" \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index f63cb590..20440dc4 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -20,7 +20,7 @@ # CREATE @click.command("hyp-jumpstart-endpoint") -@click.option("--version", default="1.0", help="Schema version to use") +@click.option("--version", default="1.1", help="Schema version to use") @click.option("--debug", default=False, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_jumpstart_inference_template", @@ -37,7 +37,7 @@ def js_create(version, debug, js_endpoint): @click.command("hyp-custom-endpoint") -@click.option("--version", default="1.0", help="Schema version to use") +@click.option("--version", default="1.1", help="Schema version to use") @click.option("--debug", default=False, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_custom_inference_template", diff --git a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py index ff4e4fc6..5e971868 100644 --- a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py +++ b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py @@ -255,6 +255,16 @@ class 
SageMakerEndpoint(BaseModel): ) +class Validations(BaseModel): + model_config = ConfigDict(extra='forbid') + + acceleratorPartitionValidation: Optional[bool] = Field( + default=True, + alias="accelerator_partition_validation", + description="Enable MIG validation for GPU partitioning. Default is true." + ) + + class Server(BaseModel): model_config = ConfigDict(extra="forbid") @@ -268,6 +278,17 @@ class Server(BaseModel): description="The EC2 instance type to use for the inference server. Must be one of the supported types.", ) + acceleratorPartitionType: Optional[str] = Field( + default=None, + alias="accelerator_partition_type", + description="MIG profile to use for GPU partitioning" + ) + + validations: Optional[Validations] = Field( + default=None, + description="Validations configuration for the server" + ) + class TlsConfig(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/src/sagemaker/hyperpod/inference/constant.py b/src/sagemaker/hyperpod/inference/constant.py new file mode 100644 index 00000000..edf6fa78 --- /dev/null +++ b/src/sagemaker/hyperpod/inference/constant.py @@ -0,0 +1,58 @@ +INSTANCE_MIG_PROFILES = { + "ml.p4d.24xlarge": [ + "mig-1g.5gb", + "mig-1g.10gb", + "mig-2g.10gb", + "mig-3g.20gb", + "mig-4g.20gb", + "mig-7g.40gb" + ], + "ml.p4de.24xlarge": [ + "mig-1g.5gb", + "mig-1g.10gb", + "mig-2g.10gb", + "mig-3g.20gb", + "mig-4g.20gb", + "mig-7g.40gb" + ], + "ml.p5.48xlarge": [ + "mig-1g.10gb", + "mig-1g.20gb", + "mig-2g.20gb", + "mig-3g.40gb", + "mig-4g.40gb", + "mig-7g.80gb" + ], + "ml.p5e.48xlarge": [ + "mig-1g.18gb", + "mig-1g.35gb", + "mig-2g.35gb", + "mig-3g.71gb", + "mig-4g.71gb", + "mig-7g.141gb" + ], + "ml.p5en.48xlarge": [ + "mig-1g.18gb", + "mig-1g.35gb", + "mig-2g.35gb", + "mig-3g.71gb", + "mig-4g.71gb", + "mig-7g.141gb" + ], + "p6-b200.48xlarge": [ + "mig-1g.23gb", + "mig-1g.47gb", + "mig-2g.47gb", + "mig-3g.93gb", + "mig-4g.93gb", + "mig-7g.186gb" + ], + "ml.p6e-gb200.36xlarge": [ + "mig-1g.23gb", + "mig-1g.47gb", + 
"mig-2g.47gb", + "mig-3g.93gb", + "mig-4g.93gb", + "mig-7g.186gb" + ] +} \ No newline at end of file diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py index d406dc07..e98f7dec 100644 --- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from pydantic import Field, ValidationError from sagemaker.hyperpod.inference.config.constants import * +from sagemaker.hyperpod.inference.constant import INSTANCE_MIG_PROFILES from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase from sagemaker.hyperpod.common.config.metadata import Metadata from sagemaker.hyperpod.common.utils import ( @@ -40,7 +41,7 @@ def _create_internal(self, spec, debug=False): endpoint_name = spec.sageMakerEndpoint.name if not endpoint_name and not name: - raise Exception('Either metadata name or endpoint name must be provided') + raise Exception("Either metadata name or endpoint name must be provided") if not name: name = endpoint_name @@ -48,6 +49,7 @@ def _create_internal(self, spec, debug=False): if not namespace: namespace = get_default_namespace() + # Create metadata object with labels and annotations if available metadata = Metadata( name=name, @@ -56,7 +58,11 @@ def _create_internal(self, spec, debug=False): annotations=self.metadata.annotations if self.metadata else None, ) - self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + # Only validate instance type if accelerator_partition_validation is provided + if not spec.server.acceleratorPartitionType: + self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + else: + self.validate_mig_profile(spec.server.acceleratorPartitionType, spec.server.instanceType) self.call_create_api( metadata=metadata, @@ -76,17 +82,57 @@ def create( self, debug=False ) -> None: + logger = self.get_logger() 
+ logger = setup_logging(logger, debug) spec = _HPJumpStartEndpoint(**self.model_dump(by_alias=True, exclude_none=True)) self._create_internal(spec, debug) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint_from_dict") - def create_from_dict( - self, - input: Dict, - debug = False - ) -> None: + def create_from_dict(self, input: Dict, debug=False) -> None: + logger = self.get_logger() + logger = setup_logging(logger, debug) + spec = _HPJumpStartEndpoint.model_validate(input, by_name=True) - self._create_internal(spec, debug) + + endpoint_name = "" + name = self.metadata.name if self.metadata else None + namespace = self.metadata.namespace if self.metadata else None + + if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name: + endpoint_name = spec.sageMakerEndpoint.name + + if not endpoint_name and not name: + raise Exception('Input "name" is required if endpoint name is not provided') + + if not name: + name = endpoint_name + + if not namespace: + namespace = get_default_namespace() + + # Only validate instance type if accelerator_partition_validation is provided + if not spec.server.acceleratorPartitionType: + self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + else: + self.validate_mig_profile(spec.server.acceleratorPartitionType, spec.server.instanceType) + + self.call_create_api( + name=name, # use model name as metadata name + kind=JUMPSTART_MODEL_KIND, + namespace=namespace, + spec=spec, + debug=debug, + ) + + self.metadata = Metadata( + name=name, + namespace=namespace, + ) + + logger.info( + f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..." + ) def refresh(self): @@ -224,6 +270,40 @@ def validate_instance_type(self, model_id: str, instance_type: str): f"Current HyperPod cluster does not have instance type {instance_type}. 
Supported instance types are {cluster_instance_types}" ) + def validate_mig_profile(self, mig_profile: str, instance_type: str): + """ + Validate if the MIG profile is supported for the given instance type. + + Args: + instance_type: SageMaker instance type (e.g., "ml.p4d.24xlarge") + mig_profile: MIG profile (e.g., "1g.10gb") + + Raises: + ValueError: If the instance type doesn't support MIG profiles or if the MIG profile is not supported for the instance type + """ + logger = self.get_logger() + logger = setup_logging(logger) + + if instance_type not in INSTANCE_MIG_PROFILES: + error_msg = ( + f"Instance type '{instance_type}' does not support MIG profiles. " + f"Supported instance types: {list(INSTANCE_MIG_PROFILES.keys())}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if mig_profile not in INSTANCE_MIG_PROFILES[instance_type]: + error_msg = ( + f"MIG profile '{mig_profile}' is not supported for instance type '{instance_type}'. " + f"Supported MIG profiles for {instance_type}: {INSTANCE_MIG_PROFILES[instance_type]}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + logger.info( + f"MIG profile '{mig_profile}' is valid for instance type '{instance_type}'" + ) + @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint") def list_pods(cls, namespace=None, endpoint_name=None): @@ -255,4 +335,4 @@ def list_pods(cls, namespace=None, endpoint_name=None): # out the pods that are created by jumpstart endpoint pods.append(item.metadata.name) - return pods + return pods \ No newline at end of file diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index c9e3e695..a85c1c00 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -29,6 +29,7 @@ # --------- JumpStart Commands --------- @patch("sys.argv", ["pytest", "--version", "1.0"]) + def test_js_create_with_required_args(): """ Test js_create with all required options via CLI runner, 
mocking schema and endpoint. @@ -47,11 +48,82 @@ def test_js_create_with_required_args(): "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" ) as mock_model_validation, patch( "sagemaker.hyperpod.common.cli_decorators._namespace_exists" - ) as mock_namespace_exists: + ) as mock_namespace_exists, patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.HPJumpStartEndpoint.validate_instance_type" + ) as mock_validate_instance, patch( + "sagemaker.hyperpod.common.utils.get_jumpstart_model_instance_types" + ) as mock_get_instance_types, patch( + "sagemaker.hyperpod.common.utils.get_cluster_instance_types" + ) as mock_get_cluster_types, patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.HPJumpStartEndpoint.create" + ) as mock_create: # Mock enhanced error handling mock_model_validation.return_value = True # Allow test model-id mock_namespace_exists.return_value = True # Allow test namespace + mock_validate_instance.return_value = None # Skip validation + mock_get_instance_types.return_value = [ + "ml.p4d.24xlarge" + ] # Mock supported types + mock_get_cluster_types.return_value = ["ml.p4d.24xlarge"] # Mock cluster types + mock_create.return_value = None # Mock successful creation + + # Prepare mock model-to-domain mapping + mock_model_class = Mock() + mock_model_instance = Mock() + domain_obj = Mock() + domain_obj.create = mock_create + mock_model_instance.to_domain.return_value = domain_obj + mock_model_class.return_value = mock_model_instance + + # Set up the registry for version 1.0 + jreg.SCHEMA_REGISTRY["1.0"] = mock_model_class + + runner = CliRunner() + result = runner.invoke( + js_create, + [ + "--namespace", + "test-ns", + "--version", + "1.0", + "--model-id", + "test-model-id", + "--instance-type", + "ml.p4d.24xlarge", # Use a supported instance type + "--endpoint-name", + "test-endpoint", + ], + ) + + assert result.exit_code == 0, result.output + mock_create.assert_called_once_with(debug=False) + + +def 
test_js_create_missing_required_args(): + runner = CliRunner() + result = runner.invoke(js_create, []) + assert result.exit_code != 0 + assert "Missing option" in result.output + + +def test_js_create_with_mig_profile(): + """ + Test js_create with MIG profile (accelerator partition) options using v1.1 schema. + """ + with patch( + "sagemaker.hyperpod.cli.inference_utils.load_schema_for_version" + ) as mock_load_schema, patch( + "sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint" + ) as mock_endpoint_class, patch( + "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" + ) as mock_model_validation, patch( + "sagemaker.hyperpod.common.cli_decorators._namespace_exists" + ) as mock_namespace_exists: + + # Mock enhanced error handling + mock_model_validation.return_value = True + mock_namespace_exists.return_value = True # Mock schema loading mock_load_schema.return_value = { @@ -71,7 +143,7 @@ def test_js_create_with_required_args(): mock_endpoint_class.model_construct.return_value = domain_obj jreg.SCHEMA_REGISTRY.clear() - jreg.SCHEMA_REGISTRY["1.0"] = mock_model_class + jreg.SCHEMA_REGISTRY["1.1"] = mock_model_class runner = CliRunner() result = runner.invoke( @@ -80,11 +152,15 @@ def test_js_create_with_required_args(): "--namespace", "test-ns", "--version", - "1.0", + "1.1", "--model-id", "test-model-id", "--instance-type", - "ml.t2.micro", + "ml.p4d.24xlarge", + "--accelerator-partition-type", + "mig-1g.5gb", + "--accelerator-partition-validation", + "true", "--endpoint-name", "test-endpoint", ], @@ -93,6 +169,12 @@ def test_js_create_with_required_args(): assert result.exit_code == 0, result.output domain_obj.create.assert_called_once_with(debug=False) + # Verify the model instance was created with MIG profile parameters + mock_model_class.assert_called_once() + call_args = mock_model_class.call_args[1] + assert "accelerator_partition_type" in call_args + assert "accelerator_partition_validation" in call_args + def 
test_js_create_missing_required_args(): runner = CliRunner() @@ -101,6 +183,63 @@ def test_js_create_missing_required_args(): assert "Missing option" in result.output +def test_js_create_mig_validation_error_handling(): + """ + Test js_create properly handles MIG profile validation errors using v1.1 schema. + """ + with patch( + "sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint" + ) as mock_endpoint_class, patch( + "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" + ) as mock_model_validation, patch( + "sagemaker.hyperpod.common.cli_decorators._namespace_exists" + ) as mock_namespace_exists: + + # Mock enhanced error handling + mock_model_validation.return_value = True + mock_namespace_exists.return_value = True + + # Prepare mock model-to-domain mapping that raises validation error + mock_model_class = Mock() + mock_model_instance = Mock() + domain_obj = Mock() + # Simulate MIG validation error during create + domain_obj.create.side_effect = ValueError( + "MIG profile '1g.5gb' is not supported for instance type 'ml.c5.2xlarge'" + ) + mock_model_instance.to_domain.return_value = domain_obj + mock_model_class.return_value = mock_model_instance + mock_endpoint_class.model_construct.return_value = domain_obj + + # Set up the registry for version 1.1 + jreg.SCHEMA_REGISTRY["1.1"] = mock_model_class + + runner = CliRunner() + result = runner.invoke( + js_create, + [ + "--namespace", + "test-ns", + "--version", + "1.1", + "--model-id", + "test-model-id", + "--instance-type", + "ml.c5.2xlarge", # Instance type that doesn't support MIG + "--accelerator-partition-type", + "1g.5gb", # Invalid MIG profile for this instance + "--accelerator-partition-validation", + "true", + "--endpoint-name", + "test-endpoint", + ], + ) + + # Should fail due to MIG validation error + assert result.exit_code != 0 + assert "MIG profile" in result.output or "not supported" in result.output + + @patch("sagemaker.hyperpod.common.cli_decorators._namespace_exists") 
@patch("sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint") def test_js_list(mock_hp, mock_namespace_exists): @@ -497,4 +636,4 @@ def test_custom_create_with_intelligent_routing_and_kv_cache(): ) assert result.exit_code == 0, result.output - domain_obj.create.assert_called_once_with(debug=False) + domain_obj.create.assert_called_once_with(debug=False) \ No newline at end of file diff --git a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py index 09999b56..a418dea9 100644 --- a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py +++ b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py @@ -7,9 +7,11 @@ Server, SageMakerEndpoint, TlsConfig, + Validations, ) from sagemaker.hyperpod.common.config import Metadata + class TestHPJumpStartEndpoint(unittest.TestCase): def setUp(self): @@ -35,8 +37,13 @@ def setUp(self): @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") - @patch('sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace', return_value='default') - def test_create(self, mock_get_namespace, mock_create_api, mock_validate_instance_type): + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create( + self, mock_get_namespace, mock_create_api, mock_validate_instance_type + ): self.endpoint.create() @@ -48,18 +55,17 @@ def test_create(self, mock_get_namespace, mock_create_api, mock_validate_instanc ) self.assertEqual(self.endpoint.metadata.name, "bert-testing-jumpstart-7-2-2") - @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") def test_create_with_metadata(self, mock_create_api, mock_validate_instance_type): """Test create_from_dict uses metadata name and namespace when endpoint name not provided""" - + # Create endpoint without sageMakerEndpoint name to force 
using metadata endpoint_without_name = HPJumpStartEndpoint( model=Model(model_id="test-model"), server=Server(instance_type="ml.c5.2xlarge"), tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), - metadata=Metadata(name="metadata-test-name", namespace="metadata-test-ns") + metadata=Metadata(name="metadata-test-name", namespace="metadata-test-ns"), ) endpoint_without_name.create() @@ -73,8 +79,13 @@ def test_create_with_metadata(self, mock_create_api, mock_validate_instance_type @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") - @patch('sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace', return_value='default') - def test_create_from_dict(self, mock_get_namespace, mock_create_api, mock_validate_instance_type): + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict( + self, mock_get_namespace, mock_create_api, mock_validate_instance_type + ): input_dict = self.endpoint.model_dump(exclude_none=True) @@ -178,13 +189,7 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api): mock_pod3, ] - mock_list_api.return_value = { - "items": [ - { - "metadata": {"name": "js-endpoint"} - } - ] - } + mock_list_api.return_value = {"items": [{"metadata": {"name": "js-endpoint"}}]} result = self.endpoint.list_pods(namespace="test-ns") @@ -211,9 +216,280 @@ def test_list_pods_with_endpoint_name(self, mock_verify_config, mock_core_api): mock_pod3, ] - result = self.endpoint.list_pods(namespace="test-ns", endpoint_name="js-endpoint1") + result = self.endpoint.list_pods( + namespace="test-ns", endpoint_name="js-endpoint1" + ) self.assertEqual(result, ["js-endpoint1-pod1", "js-endpoint1-pod2"]) mock_core_api.return_value.list_namespaced_pod.assert_called_once_with( namespace="test-ns" ) + + def test_validate_mig_profile_valid(self): + """Test validate_mig_profile with valid 
instance type and MIG profile""" + # Test with valid combinations + self.endpoint.validate_mig_profile("mig-1g.5gb", "ml.p4d.24xlarge") + self.endpoint.validate_mig_profile("mig-7g.40gb", "ml.p4d.24xlarge") + self.endpoint.validate_mig_profile("mig-1g.10gb", "ml.p4de.24xlarge") + self.endpoint.validate_mig_profile("mig-7g.80gb", "ml.p5.48xlarge") + + def test_validate_mig_profile_invalid_instance_type(self): + """Test validate_mig_profile with unsupported instance type""" + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("1g.5gb", "ml.c5.2xlarge") + + self.assertIn( + "Instance type 'ml.c5.2xlarge' does not support MIG profiles", + str(context.exception), + ) + self.assertIn("Supported instance types:", str(context.exception)) + + def test_validate_mig_profile_invalid_mig_profile(self): + """Test validate_mig_profile with unsupported MIG profile for valid instance type""" + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("invalid.profile", "ml.p4d.24xlarge") + + self.assertIn( + "MIG profile 'invalid.profile' is not supported for instance type 'ml.p4d.24xlarge'", + str(context.exception), + ) + self.assertIn( + "Supported MIG profiles for ml.p4d.24xlarge:", str(context.exception) + ) + + def test_validate_mig_profile_wrong_profile_for_instance(self): + """Test validate_mig_profile with MIG profile that exists but not for the specific instance type""" + # 7g.80gb is valid for p4de but not p4d + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("7g.80gb", "ml.p4d.24xlarge") + + self.assertIn( + "MIG profile '7g.80gb' is not supported for instance type 'ml.p4d.24xlarge'", + str(context.exception), + ) + + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def 
test_create_with_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_mig + ): + """Test create method uses MIG validation when accelerator_partition_validation is True""" + # Create endpoint with accelerator partition validation enabled + model = Model(model_id="test-model") + validations = Validations( + accelerator_partition_validation=True, + ) + server = Server( + instance_type="ml.p4d.24xlarge", + validations=validations, + accelerator_partition_type="1g.5gb", + ) + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should call validate_mig_profile instead of validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_create_api.assert_called_once() + + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_without_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_instance + ): + """Test create method uses instance type validation when accelerator_partition_validation is False/None""" + # Create endpoint without accelerator partition validation (default behavior) + model = Model(model_id="test-model") + server = Server(instance_type="ml.c5.2xlarge") + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should call validate_instance_type instead of validate_mig_profile + mock_validate_instance.assert_called_once_with("test-model", "ml.c5.2xlarge") + mock_create_api.assert_called_once() 
+ + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict_with_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_mig + ): + """Test create_from_dict method uses MIG validation when accelerator_partition_validation is True""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": { + "instanceType": "ml.p4d.24xlarge", + "validations": { + "acceleratorPartitionValidation": True + }, + "acceleratorPartitionType": "1g.5gb", + }, + "sageMakerEndpoint": {"name": "test-endpoint"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + } + + endpoint = HPJumpStartEndpoint( + model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + ) + endpoint.create_from_dict(input_dict) + + # Should call validate_mig_profile instead of validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_create_api.assert_called_once() + + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict_without_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_instance + ): + """Test create_from_dict method uses instance type validation when accelerator_partition_validation is False/None""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": {"instanceType": "ml.c5.2xlarge"}, + "sageMakerEndpoint": {"name": "test-endpoint"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + } + + endpoint = HPJumpStartEndpoint( + 
model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + ) + endpoint.create_from_dict(input_dict) + + # Should call validate_instance_type instead of validate_mig_profile + mock_validate_instance.assert_called_once_with("test-model", "ml.c5.2xlarge") + mock_create_api.assert_called_once() + + def test_validate_mig_profile_edge_cases(self): + """Test validate_mig_profile with various edge cases""" + # Test with different instance types and their specific profiles + test_cases = [ + ("ml.p4de.24xlarge", "mig-1g.5gb"), + ("ml.p5.48xlarge", "mig-3g.40gb"), + ("ml.p5e.48xlarge", "mig-1g.18gb"), + ("ml.p5en.48xlarge", "mig-7g.141gb"), + ("p6-b200.48xlarge", "mig-1g.23gb"), + ("ml.p6e-gb200.36xlarge", "mig-7g.186gb"), + ] + + for instance_type, mig_profile in test_cases: + with self.subTest(instance_type=instance_type, mig_profile=mig_profile): + # Should not raise any exception + self.endpoint.validate_mig_profile(mig_profile, instance_type) + + def test_validate_mig_profile_case_sensitivity(self): + """Test that MIG profile validation is case sensitive""" + with self.assertRaises(ValueError): + # Test uppercase - should fail as profiles are lowercase + self.endpoint.validate_mig_profile("1G.5GB", "ml.p4d.24xlarge") + + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_validation_logic_priority( + self, + mock_get_namespace, + mock_create_api, + mock_validate_instance, + mock_validate_mig, + ): + """Test that accelerator_partition_validation takes priority over regular validation""" + # Create endpoint with both accelerator partition validation and regular fields + model = Model(model_id="test-model") + validations = 
Validations( + accelerator_partition_validation=True, + ) + server = Server( + instance_type="ml.p4d.24xlarge", + validations=validations, + accelerator_partition_type="1g.5gb", + ) + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should only call validate_mig_profile, not validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_validate_instance.assert_not_called() + mock_create_api.assert_called_once() + + def test_create_missing_name_and_endpoint_name(self): + """Test create method raises exception when both metadata name and endpoint name are missing""" + model = Model(model_id="test-model") + server = Server(instance_type="ml.c5.2xlarge") + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + # No sageMakerEndpoint name and no metadata + ) + + with self.assertRaises(Exception) as context: + endpoint.create() + + self.assertIn( + "Either metadata name or endpoint name must be provided", + str(context.exception), + ) + + def test_create_from_dict_missing_name_and_endpoint_name(self): + """Test create_from_dict method raises exception when both name and endpoint name are missing""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": {"instanceType": "ml.c5.2xlarge"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + # No sageMakerEndpoint name + } + + endpoint = HPJumpStartEndpoint( + model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + # No metadata + ) + + with self.assertRaises(Exception) as context: + endpoint.create_from_dict(input_dict) + + self.assertIn( + 'Input "name" is required if endpoint name is not provided', + 
str(context.exception), + ) \ No newline at end of file From a82841b88aaf1a7b1b3109d1405a8f6965b92545 Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 19:59:50 -0800 Subject: [PATCH 28/31] update changelog --- CHANGELOG.md | 12 ++++++++++++ hyperpod-custom-inference-template/CHANGELOG.md | 1 + hyperpod-jumpstart-inference-template/CHANGELOG.md | 6 ++++++ hyperpod-pytorch-job-template/CHANGELOG.md | 6 ++++++ hyperpod-space-template/CHANGELOG.md | 6 ++++++ setup.py | 2 +- 6 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 hyperpod-space-template/CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 4dc1d7d3..e9377db3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v.3.4.0 (2025-11-20) + +### Features + + * HyperPod Dev Spaces template for data scientists to create, manage, and access interactive ML development environments with configurable resource allocation and namespace isolation + * Support for KVCaching, intelligent routing, tiered storage, MIG + * Support for fractional gpu + * Support KVCache and Intelligent Routing support in template version 1.1 + * User can modify jinja template to add parameters supported by CRD through init experience, for further CLI customization + * MIG support for model deployment on SageMaker Hyperpod Inference + + ## v.3.3.1 (2025-10-30) ### Features diff --git a/hyperpod-custom-inference-template/CHANGELOG.md b/hyperpod-custom-inference-template/CHANGELOG.md index effe0b04..565df479 100644 --- a/hyperpod-custom-inference-template/CHANGELOG.md +++ b/hyperpod-custom-inference-template/CHANGELOG.md @@ -4,6 +4,7 @@ * Support KVCache and Intelligent Routing support in template version 1.1 * User can modify jinja template to add parameters supported by CRD through init experience, for further CLI customization +* Support for MIG ## v1.0.1 (2025-08-27) diff --git a/hyperpod-jumpstart-inference-template/CHANGELOG.md 
b/hyperpod-jumpstart-inference-template/CHANGELOG.md index d7f796de..9afbd9a2 100644 --- a/hyperpod-jumpstart-inference-template/CHANGELOG.md +++ b/hyperpod-jumpstart-inference-template/CHANGELOG.md @@ -1,3 +1,9 @@ +## v1.1.0 (2025-11-20) + +### Features + +* Support for KVCaching, intelligent routing, tiered storage, MIG + ## v1.0.3 (2025-10-30) ### Features diff --git a/hyperpod-pytorch-job-template/CHANGELOG.md b/hyperpod-pytorch-job-template/CHANGELOG.md index d525c429..b872a9c4 100644 --- a/hyperpod-pytorch-job-template/CHANGELOG.md +++ b/hyperpod-pytorch-job-template/CHANGELOG.md @@ -1,3 +1,9 @@ +## v1.2.0 (2025-11-20) + +### Features + +* Support for fractional gpu + ## v1.1.4 (2025-10-30) ### Features diff --git a/hyperpod-space-template/CHANGELOG.md b/hyperpod-space-template/CHANGELOG.md new file mode 100644 index 00000000..5c47e7f5 --- /dev/null +++ b/hyperpod-space-template/CHANGELOG.md @@ -0,0 +1,6 @@ +## v1.0.0 (2025-11-20) + +### Features + +* HyperPod Dev Spaces template for data scientists to create, manage, and access interactive ML development environments with configurable resource allocation and namespace isolation + diff --git a/setup.py b/setup.py index 14014833..1476b311 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( data_files=sagemaker_hyperpod_recipes, name="sagemaker-hyperpod", - version="3.3.1", + version="3.4.0", description="Amazon SageMaker HyperPod SDK and CLI", long_description=open("README.md").read(), long_description_content_type="text/markdown", From 0c1cc04a79c79a8752bf47402fec91e1f057568c Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 21:25:04 -0800 Subject: [PATCH 29/31] uncommented install req --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 1476b311..e78ed227 100644 --- a/setup.py +++ b/setup.py @@ -90,9 +90,8 @@ "hyperpod-pytorch-job-template>=1.0.0, <2.0.0", "hyperpod-custom-inference-template>=1.0.0, <2.0.0", 
"hyperpod-jumpstart-inference-template>=1.0.0, <2.0.0", - "hyperpod-cluster-stack-template>=1.0.0, <2.0.0" - # TODO: need to uncomment before pushing to master - # "hyperpod_space_template>=1.0.0, <2.0.0" + "hyperpod-cluster-stack-template>=1.0.0, <2.0.0", + "hyperpod-space-template>=1.0.0, <2.0.0" ], entry_points={ "console_scripts": [ From 88762882f5dca93f97546569b06432e08e99ccb6 Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 21:51:40 -0800 Subject: [PATCH 30/31] uncommented --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index e78ed227..757ad6f8 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( data_files=sagemaker_hyperpod_recipes, name="sagemaker-hyperpod", - version="3.4.0", + version="3.3.1", description="Amazon SageMaker HyperPod SDK and CLI", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -90,8 +90,8 @@ "hyperpod-pytorch-job-template>=1.0.0, <2.0.0", "hyperpod-custom-inference-template>=1.0.0, <2.0.0", "hyperpod-jumpstart-inference-template>=1.0.0, <2.0.0", - "hyperpod-cluster-stack-template>=1.0.0, <2.0.0", - "hyperpod-space-template>=1.0.0, <2.0.0" + "hyperpod-cluster-stack-template>=1.0.0, <2.0.0" + "hyperpod_dev_space_template>=1.0.0, <2.0.0" ], entry_points={ "console_scripts": [ From 6455f6a2be41dcea8535e36290a92b429de9f82f Mon Sep 17 00:00:00 2001 From: Mohamed Zeidan Date: Thu, 20 Nov 2025 21:52:20 -0800 Subject: [PATCH 31/31] fixed uncomment --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 757ad6f8..23554355 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( data_files=sagemaker_hyperpod_recipes, name="sagemaker-hyperpod", - version="3.3.1", + version="3.4.0", description="Amazon SageMaker HyperPod SDK and CLI", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -91,7 +91,8 @@ 
"hyperpod-custom-inference-template>=1.0.0, <2.0.0", "hyperpod-jumpstart-inference-template>=1.0.0, <2.0.0", "hyperpod-cluster-stack-template>=1.0.0, <2.0.0" - "hyperpod_dev_space_template>=1.0.0, <2.0.0" + # TODO: need to uncomment before pushing to master + "hyperpod_space_template>=1.0.0, <2.0.0" ], entry_points={ "console_scripts": [