Skip to content

Commit 36727cd

Browse files
Pathways-on-Cloud Team and copybara-github
authored and committed
Make Pathways proxy server image configurable
The user can optionally pass a custom Pathways proxy server image. This will allow them to use the image corresponding to their head pod. PiperOrigin-RevId: 859207127
1 parent e7d15cc commit 36727cd

4 files changed

Lines changed: 50 additions & 1 deletion

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_JAX_PLATFORM_PROXY = "proxy"
2828
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
2929
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
30+
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe"
3031

3132
_logger = logging.getLogger(__name__)
3233

@@ -36,6 +37,7 @@ def _deploy_pathways_proxy_server(
3637
proxy_job_name: str,
3738
expected_instances: Mapping[Any, Any],
3839
gcs_scratch_location: str,
40+
proxy_server_image: str,
3941
) -> None:
4042
"""Deploys the Pathways proxy pods to the GKE cluster.
4143
@@ -45,6 +47,7 @@ def _deploy_pathways_proxy_server(
4547
expected_instances: A dictionary mapping instance types to the number of
4648
instances.
4749
gcs_scratch_location: The Google Cloud Storage location to use.
50+
proxy_server_image: The image to use for the proxy server.
4851
4952
Raises:
5053
subprocess.CalledProcessError: If the kubectl command fails.
@@ -70,6 +73,7 @@ def _deploy_pathways_proxy_server(
7073
PATHWAYS_HEAD_PORT=pathways_head_port,
7174
EXPECTED_INSTANCES=instances_str,
7275
GCS_SCRATCH_LOCATION=gcs_scratch_location,
76+
PROXY_SERVER_IMAGE=proxy_server_image,
7377
)
7478

7579
_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
@@ -89,6 +93,8 @@ class _ISCPathways:
8993
pathways_service: The service name and port of the Pathways head pod.
9094
expected_tpu_instances: A dictionary mapping TPU machine types to the number
9195
of instances.
96+
proxy_job_name: The name to use for the deployed proxy.
97+
proxy_server_image: The image to use for the proxy server.
9298
"""
9399

94100
def __init__(
@@ -100,6 +106,7 @@ def __init__(
100106
pathways_service: str,
101107
expected_tpu_instances: Mapping[Any, Any],
102108
proxy_job_name: str | None,
109+
proxy_server_image: str,
103110
):
104111
"""Initializes the TPU manager."""
105112
self.cluster = cluster
@@ -115,6 +122,7 @@ def __init__(
115122
self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}"
116123
self._port_forward_process = None
117124
self._proxy_port = None
125+
self.proxy_server_image = proxy_server_image
118126

119127
def __repr__(self):
120128
return (
@@ -133,6 +141,7 @@ def __enter__(self):
133141
proxy_job_name=self._proxy_job_name,
134142
expected_instances=self.expected_tpu_instances,
135143
gcs_scratch_location=self.bucket,
144+
proxy_server_image=self.proxy_server_image,
136145
)
137146
# Print a link to Cloud Logging
138147
cloud_logging_link = gke_utils.get_log_link(
@@ -196,6 +205,7 @@ def connect(
196205
pathways_service: str,
197206
expected_tpu_instances: Mapping[str, int],
198207
proxy_job_name: str | None = None,
208+
proxy_server_image: str | None = _DEFAULT_PROXY_IMAGE,
199209
) -> Iterator["_ISCPathways"]:
200210
"""Connects to a Pathways server if the cluster exists. If not, creates it.
201211
@@ -209,13 +219,16 @@ def connect(
209219
of instances. For example: {"tpuv6e:2x2": 2}
210220
proxy_job_name: The name to use for the deployed proxy. If not provided, a
211221
random name will be generated.
222+
proxy_server_image: The proxy server image to use. If not provided, a
223+
default will be used.
212224
213225
Yields:
214226
The Pathways manager.
215227
"""
216228
_logger.info("Validating Pathways service and TPU instances...")
217229
validators.validate_pathways_service(pathways_service)
218230
validators.validate_tpu_instances(expected_tpu_instances)
231+
validators.validate_proxy_server_image(proxy_server_image)
219232
_logger.info("Validation complete.")
220233
gke_utils.fetch_cluster_credentials(
221234
cluster_name=cluster, project_id=project, location=region
@@ -229,5 +242,6 @@ def connect(
229242
pathways_service=pathways_service,
230243
expected_tpu_instances=expected_tpu_instances,
231244
proxy_job_name=proxy_job_name,
245+
proxy_server_image=proxy_server_image,
232246
) as t:
233247
yield t

pathwaysutils/experimental/shared_pathways_service/run_connect_example.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
"tpu_type", "tpuv6e:2x2", "The TPU machine type and topology."
2525
)
2626
flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.")
27+
flags.DEFINE_string(
28+
"proxy_job_name",
29+
None,
30+
"The name to use for the deployed proxy. If not provided, a random name"
31+
" will be generated.",
32+
)
33+
flags.DEFINE_string(
34+
"proxy_server_image",
35+
None,
36+
"The proxy server image to use. If not provided, a default will be used.",
37+
)
2738

2839
flags.mark_flags_as_required([
2940
"cluster",
@@ -37,13 +48,21 @@
3748
def main(argv: Sequence[str]) -> None:
3849
if len(argv) > 1:
3950
raise app.UsageError("Too many command-line arguments.")
51+
52+
kwargs = {}
53+
if FLAGS.proxy_job_name:
54+
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
55+
if FLAGS.proxy_server_image:
56+
kwargs["proxy_server_image"] = FLAGS.proxy_server_image
57+
4058
with isc_pathways.connect(
4159
cluster=FLAGS.cluster,
4260
project=FLAGS.project,
4361
region=FLAGS.region,
4462
gcs_bucket=FLAGS.gcs_bucket,
4563
pathways_service=FLAGS.pathways_service,
4664
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
65+
**kwargs,
4766
):
4867
orig_matrix = jnp.zeros(5)
4968
result_matrix = orig_matrix + 1

pathwaysutils/experimental/shared_pathways_service/validators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,19 @@ def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None:
8989

9090
inst = next(iter(expected_tpu_instances.keys()))
9191
_validate_tpu_supported(inst)
92+
93+
94+
def validate_proxy_server_image(proxy_server_image: str) -> None:
  """Validates the proxy server image format.

  The image reference must be non-empty, contain a '/' separating the registry
  or namespace from the image name, and carry a tag (':') or a digest ('@') on
  the final image name component.

  Args:
    proxy_server_image: The container image reference to validate, e.g.
      "registry/namespace/name:tag" or "registry/namespace/name@sha256:...".

  Raises:
    ValueError: If the image is empty or blank, has no '/', or has neither a
      tag nor a digest on its final path component.
  """
  if not proxy_server_image or not proxy_server_image.strip():
    raise ValueError("Proxy server image cannot be empty.")
  if "/" not in proxy_server_image:
    raise ValueError(
        f"Proxy server image '{proxy_server_image}' must contain '/', "
        "separating the registry or namespace from the final image name."
    )
  # Only the final path component can carry the tag or digest. Searching the
  # whole reference would let a registry port (e.g. "localhost:5000/image")
  # masquerade as a tag.
  image_name = proxy_server_image.rsplit("/", 1)[-1]
  if ":" not in image_name and "@" not in image_name:
    raise ValueError(
        f"Proxy server image '{proxy_server_image}' must contain a tag with ':'"
        " or a digest with '@'."
    )

pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ spec:
1414
automountServiceAccountToken: false
1515
containers:
1616
- name: pathways-proxy
17-
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.8.0@sha256:5296fa0819d8cbdfbcf951ffca2072128255411557240624ff4011522a6a2abe
17+
image: ${PROXY_SERVER_IMAGE}
1818
imagePullPolicy: Always
1919
args:
2020
- --server_port=${PROXY_SERVER_PORT}

0 commit comments

Comments
 (0)