Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 22 additions & 54 deletions sagemaker-train/tests/integ/train/aws_batch/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ class BatchTestResourceManager:
def __init__(
self,
batch_client,
test_id,
queue_name="pysdk-test-qm-queue",
service_env_name="pysdk-test-qm-queue-service-environment",
scheduling_policy_name="pysdk-test-qm-scheduling-policy",
quota_share_name="pysdk-test-quota-share",
):
self.test_id = test_id
self.batch_client = batch_client
self.queue_name = queue_name
self.service_environment_name = service_env_name
self.scheduling_policy_name = scheduling_policy_name
self.quota_share_name = quota_share_name
self.queue_name = f"{queue_name}-{test_id}"
self.service_environment_name = f"{service_env_name}-{test_id}"
self.scheduling_policy_name = f"{scheduling_policy_name}-{test_id}"
self.quota_share_name = f"{quota_share_name}-{test_id}"

def _create_or_get_service_environment(self, service_environment_name):
print(f"Creating service environment: {service_environment_name}")
Expand Down Expand Up @@ -277,65 +279,31 @@ def _delete_quota_share(self, quota_share_arn: str):
print("Waiting for QuotaShare deletion to finish...")
self._wait_for_quota_share_state(quota_share_arn, "DELETED", "DISABLED")

def get_or_create_resources(
self,
queue_name=None,
service_environment_name=None,
scheduling_policy_name=None,
quota_share_name=None
):
queue_name = queue_name or self.queue_name
service_environment_name = service_environment_name or self.service_environment_name
scheduling_policy_name = scheduling_policy_name or self.scheduling_policy_name
quota_share_name = quota_share_name or self.quota_share_name

service_environment = self._create_or_get_service_environment(service_environment_name)
if service_environment.get("state") != "ENABLED":
self._update_service_environment_state(service_environment_name, "ENABLED")
self._wait_for_service_environment_state(service_environment_name, "VALID", "ENABLED")
time.sleep(10)

scheduling_policy = self._create_or_get_scheduling_policy(scheduling_policy_name)
scheduling_policy_arn = scheduling_policy.get("arn")

queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"],
scheduling_policy_arn)
if queue.get("state") != "ENABLED":
self._update_queue_state(queue_name, "ENABLED")
self._wait_for_queue_state(queue_name, "VALID", "ENABLED")
time.sleep(10)

quota_share = self._create_or_get_quota_share(quota_share_name, queue_name)
if quota_share.get("state") != "ENABLED":
self._update_quota_share_state(quota_share["quotaShareArn"], "ENABLED")
self._wait_for_quota_share_state(quota_share["quotaShareArn"], "VALID", "ENABLED")
time.sleep(10)
def get_or_create_resources(self):
service_environment = self._create_or_get_service_environment(self.service_environment_name)
scheduling_policy = self._create_or_get_scheduling_policy(self.scheduling_policy_name)

return queue, service_environment, scheduling_policy, quota_share
queue = self._create_or_get_queue(self.queue_name, service_environment["serviceEnvironmentArn"],
scheduling_policy.get("arn"))
self._wait_for_queue_state(self.queue_name, "VALID", "ENABLED")

def delete_resources(
self,
queue_name=None,
service_environment_name=None,
scheduling_policy_name=None,
quota_share_name=None
):
queue_name = queue_name or self.queue_name
service_environment_name = service_environment_name or self.service_environment_name
scheduling_policy_name = scheduling_policy_name or self.scheduling_policy_name
quota_share_name = quota_share_name or self.quota_share_name
quota_share = self._create_or_get_quota_share(self.quota_share_name, self.queue_name)
self._wait_for_quota_share_state(quota_share["quotaShareArn"], "VALID", "ENABLED")

return queue, service_environment, scheduling_policy, quota_share

def delete_resources(self):
# Get ARNs needed for deletion
desc_jq = self.batch_client.describe_job_queues(jobQueues=[queue_name])
desc_jq = self.batch_client.describe_job_queues(jobQueues=[self.queue_name])
if desc_jq["jobQueues"]:
jq_arn = desc_jq["jobQueues"][0]["jobQueueArn"]
quota_share_arn = f"{jq_arn}/quota-share/{quota_share_name}"
quota_share_arn = f"{jq_arn}/quota-share/{self.quota_share_name}"
self._delete_quota_share(quota_share_arn)

self._delete_job_queue(queue_name)
self._delete_job_queue(self.queue_name)

sp = self._find_scheduling_policy(scheduling_policy_name)
sp = self._find_scheduling_policy(self.scheduling_policy_name)
if sp:
self._delete_scheduling_policy(sp["arn"])

self._delete_service_environment(service_environment_name)
self._delete_service_environment(self.service_environment_name)
26 changes: 23 additions & 3 deletions sagemaker-train/tests/integ/train/aws_batch/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import boto3
import botocore
import pytest
import random
import string

from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode, InputData, Compute
Expand All @@ -29,16 +31,34 @@
from .manager import BatchTestResourceManager


class ShortId:
ALPHABET = string.ascii_lowercase + string.digits
DEFAULT_LENGTH = 8

@staticmethod
def get(length=DEFAULT_LENGTH):
return "".join(random.choices(ShortId.ALPHABET, k=length))


@pytest.fixture(scope="module")
def batch_client():
return boto3.client("batch", region_name="us-west-2")


@pytest.fixture(scope="function")
def batch_test_resource_manager(batch_client):
resource_manager = BatchTestResourceManager(batch_client=batch_client)
resource_manager.get_or_create_resources()
yield resource_manager
# Guarantee AWS Batch resource name uniqueness across concurrent test runtimes
test_id = ShortId.get()
print(f"Integration test ID (used in AWS Batch resource naming): {test_id}")

resource_manager = BatchTestResourceManager(batch_client=batch_client, test_id=test_id)

try:
resource_manager.get_or_create_resources()
yield resource_manager
except Exception as e:
print(f"Exception thrown while creating or yielding AWS Batch resources: {str(e)}")

resource_manager.delete_resources()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@
"metadata": {},
"source": [
"## Create TrainingQueue object\n",
"Using our queue is as easy as referring to it by name in the TrainingQueue contructor. The TrainingQueue class within the SageMaker Python SDK provides built in support for working with Batch queues."
"Using our queue is as easy as referring to it by name in the TrainingQueue constructor. The TrainingQueue class within the SageMaker Python SDK provides built in support for working with Batch queues."
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def create_quota_share(self, create_qs_request: dict):
jobQueues=[create_qs_request["jobQueue"]]
)
jq_arn = desc_jqs_resp["jobQueues"][0]["jobQueueArn"]
quota_share_arn = f"{jq_arn}/quota-share/{create_qs_request["quotaShareName"]}"
quota_share_arn = f"{jq_arn}/quota-share/{create_qs_request['quotaShareName']}"
return {
"quotaShareName": create_qs_request["quotaShareName"],
"quotaShareArn": quota_share_arn,
Expand Down
Loading