def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]:
  """Starts a `kubectl logs -f` process for the given pod.

  The child's stderr is merged into its stdout so that a caller only has to
  drain a single pipe. With a separate, never-read stderr pipe the child can
  block once that pipe's buffer fills up, and the extra pipe handle would
  also have to be closed separately.

  Args:
    pod_name: The name of the pod.

  Returns:
    The started process. Its `stdout` is a text-mode, line-buffered pipe
    carrying both the stdout and stderr output of `kubectl`. The caller is
    responsible for closing the pipe and terminating the process.

  Raises:
    Exception: If the log streaming process cannot be started. The error is
      logged and then re-raised unchanged.
  """
  command = ["kubectl", "logs", "-f", pod_name]
  try:
    return subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout: nothing reads a separate stderr pipe, and
        # an undrained pipe can stall the child process.
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # Line buffered
    )
  except Exception as e:
    _logger.exception("Error streaming logs for pod %s: %r", pod_name, e)
    raise
def _wait_for_placement(pod_name: str) -> None:
  """Streams proxy pod logs until they report that placement is complete.

  Lines containing any of a small set of placement-related keywords are
  forwarded to the module logger. Streaming stops at the first line that
  contains the phrase "unplaced -> placed" (case-insensitive) or when the log
  stream ends. The log streaming process is always shut down before
  returning, including when reading the logs raises.

  Args:
    pod_name: The name of the proxy pod whose logs are streamed.
  """
  _logger.info("Streaming proxy logs until the placement is complete...")
  log_process = gke_utils.stream_pod_logs(pod_name)

  # Matching is case-insensitive, so the keywords are stored lower-cased once
  # instead of being lower-cased again for every log line.
  keywords = (
      "placement",
      "signaling to rm",
      "transition slice",
      "failed_precondition",
  )
  end_phrase = "unplaced -> placed"
  placed = False

  try:
    if log_process.stdout:
      for line in log_process.stdout:
        line_lower = line.lower()
        if any(keyword in line_lower for keyword in keywords):
          _logger.info("Proxy log: %s", line.strip())

        if end_phrase in line_lower:
          _logger.info("TPU placement complete: %s", line.strip())
          placed = True
          break
    if not placed:
      # The stream ended (or was never available) without the end phrase;
      # make that visible instead of looking like a successful placement.
      _logger.warning(
          "Proxy log stream ended before placement completion was observed."
      )
  finally:
    # Release the pipes and the child process even if reading failed.
    if log_process.stdout:
      _logger.info("Closing log process stdout.")
      log_process.stdout.close()
    if log_process.stderr:
      log_process.stderr.close()

    # Ensure the process is terminated
    log_process.terminate()
    try:
      log_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
      _logger.warning("Log streaming process did not terminate gracefully.")
      log_process.kill()
      # Reap the killed child so it does not linger as a zombie process.
      log_process.wait()
  _logger.info("Finished waiting for placement.")
+ proxy_pod_name: The name of the proxy pod, assigned during deployment. proxy_server_image: The image to use for the proxy server. proxy_options: Configuration options for the Pathways proxy. """ @@ -171,6 +208,7 @@ def __init__( self.pathways_service = pathways_service self.expected_tpu_instances = expected_tpu_instances self._proxy_job_name = proxy_job_name + self.proxy_pod_name = "" self._port_forward_process = None self._proxy_port = None self.proxy_server_image = proxy_server_image @@ -218,9 +256,11 @@ def __enter__(self): ) _logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link) - proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name) + self.proxy_pod_name = gke_utils.wait_for_pod(self._proxy_job_name) self._proxy_port, self._port_forward_process = ( - gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT) + gke_utils.enable_port_forwarding( + self.proxy_pod_name, PROXY_SERVER_PORT + ) ) # Update the JAX backend to use the proxy. @@ -349,4 +389,16 @@ def connect( proxy_server_image=proxy_server_image, proxy_options=proxy_options, ) as t: + if t.proxy_pod_name: + placement_thread = threading.Thread( + target=_wait_for_placement, + args=(t.proxy_pod_name,), + daemon=True, + ) + placement_thread.start() + else: + _logger.warning( + "proxy_pod_name not set on _ISCPathways instance, skipping background" + " _wait_for_placement." + ) yield t