Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 35 additions & 30 deletions test/integration_tests/init/test_pytorch_job_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,21 @@ def test_configure_pytorch_job(runner, pytorch_job_name, test_directory):
configure, [
# Required fields only
"--job-name", pytorch_job_name,
"--image", "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel",
"--command", '["python", "-c", "import torch; print(torch.__version__); import time; time.sleep(3600)"]',
"--image", "448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist",
"--pull-policy", "Always",
"--tasks-per-node", "1",
"--max-retry", "1"
], catch_exceptions=False
)
assert_command_succeeded(result)

# Simplified expected_config
expected_config = {
"job_name": pytorch_job_name,
"image": "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel",
"command": ["python", "-c", "import torch; print(torch.__version__); import time; time.sleep(3600)"],
"image": "448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist",
"pull_policy": "Always",
"tasks_per_node": 1,
"max_retry": 1
}
assert_config_values("./", expected_config)

Expand All @@ -124,6 +128,31 @@ def test_create_pytorch_job(runner, pytorch_job_name, test_directory):
assert pytorch_job_name in result.output


@pytest.mark.dependency(name="list_pods", depends=["create"])
def test_list_pods(pytorch_job_name, test_directory):
"""Test listing pods for a specific job."""
# Wait a moment to ensure pods are created
time.sleep(10)

list_pods_result = execute_command([
"hyp", "list-pods", "hyp-pytorch-job",
"--job-name", pytorch_job_name,
"--namespace", NAMESPACE
])
assert list_pods_result.returncode == 0

# Verify the output contains expected headers and job name
output = list_pods_result.stdout.strip()
assert f"Pods for job: {pytorch_job_name}" in output
assert "POD NAME" in output
assert "NAMESPACE" in output

# Verify at least one pod is listed (should contain the job name in the pod name)
assert f"{pytorch_job_name}-pod-" in output

print(f"[INFO] Successfully listed pods for job: {pytorch_job_name}")


@pytest.mark.dependency(name="wait", depends=["create"])
def test_wait_for_job_running(pytorch_job_name, test_directory):
"""Poll SDK until PyTorch job reaches Running state."""
Expand Down Expand Up @@ -158,31 +187,7 @@ def test_wait_for_job_running(pytorch_job_name, test_directory):
pytest.fail(f"[ERROR] Timed out waiting for job {pytorch_job_name} to be Running")


@pytest.mark.dependency(name="list_pods", depends=["wait"])
def test_list_pods(pytorch_job_name, test_directory):
"""Test listing pods for a specific job."""
# Wait a moment to ensure pods are created
time.sleep(10)

list_pods_result = execute_command([
"hyp", "list-pods", "hyp-pytorch-job",
"--job-name", pytorch_job_name,
"--namespace", NAMESPACE
])
assert list_pods_result.returncode == 0

# Verify the output contains expected headers and job name
output = list_pods_result.stdout.strip()
assert f"Pods for job: {pytorch_job_name}" in output
assert "POD NAME" in output
assert "NAMESPACE" in output

# Verify at least one pod is listed (should contain the job name in the pod name)
assert f"{pytorch_job_name}-pod-" in output

print(f"[INFO] Successfully listed pods for job: {pytorch_job_name}")


@pytest.mark.run(order=99)
@pytest.mark.dependency(depends=["create"])
def test_pytorch_job_delete(pytorch_job_name, test_directory):
"""Clean up deployed PyTorch job using CLI delete command and verify deletion."""
Expand All @@ -198,7 +203,7 @@ def test_pytorch_job_delete(pytorch_job_name, test_directory):
time.sleep(5)

# Verify the job is no longer listed
list_result = execute_command(["hyp", "list", "hyp-pytorch-job", "--namespace", NAMESPACE])
list_result = execute_command(["hyp", "list", "hyp-pytorch-job"])
assert list_result.returncode == 0

# The job name should no longer be in the output
Expand Down