From 1c29a9f19d6b25c795ad97155b392aba0b8f982f Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Tue, 24 Mar 2026 13:43:30 -0700 Subject: [PATCH] Add background log streaming to detect TPU placement completion A background thread watches for specific log messages indicating that the proxy pod is waiting for placement until the TPU placement process has finished. This allows for better tracking of the Pathways service readiness. Continued "waiting" messages from proxy might indicate that the Pathways service doesn't have enough TPU availability to process the request. PiperOrigin-RevId: 888835053 --- .../shared_pathways_service/gke_utils.py | 29 ++++++++++ .../shared_pathways_service/isc_pathways.py | 56 ++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py index 3b30dc1..4f44fef 100644 --- a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py +++ b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py @@ -3,6 +3,7 @@ import logging import socket import subprocess +import time import urllib.parse import portpicker @@ -189,6 +190,7 @@ def wait_for_pod(job_name: str) -> str: RuntimeError: If the pod is not ready. """ _logger.info("Waiting for pod to be created...") + time.sleep(1) pod_name = get_pod_from_job(job_name) _logger.info( @@ -296,6 +298,33 @@ def enable_port_forwarding( return (port_available, port_forward_process) +def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]: + """Streams logs from the given pod. + + Args: + pod_name: The name of the pod. + + Returns: + The process for streaming the logs. + + Raises: + Exception: If the log streaming fails. + """ + command = ["kubectl", "logs", "-f", pod_name] + try: + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, # Line buffered + ) + return process + except Exception as e: + _logger.exception("Error streaming logs for pod %s: %r", pod_name, e) + raise + + def delete_gke_job(job_name: str) -> None: """Deletes the given job from the GKE cluster. diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index b25df7e..69a71e5 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -9,6 +9,7 @@ import random import string import subprocess +import threading from typing import Any import jax @@ -123,6 +124,41 @@ def _deploy_pathways_proxy_server( _logger.info("Successfully deployed Pathways proxy.") +def _wait_for_placement(pod_name: str) -> None: + """Waits for the placement to be complete by checking proxy logs.""" + _logger.info("Streaming proxy logs until the placement is complete...") + log_process = gke_utils.stream_pod_logs(pod_name) + + keywords = [ + "placement", + "Signaling to RM", + "Transition slice", + "FAILED_PRECONDITION", + ] + end_phrase = "unplaced -> placed" + + if log_process.stdout: + for line in iter(log_process.stdout.readline, ""): + line_lower = line.lower() + if any(keyword.lower() in line_lower for keyword in keywords): + _logger.info("Proxy log: %s", line.strip()) + + if end_phrase.lower() in line_lower: + _logger.info("TPU placement complete: %s", line.strip()) + break + _logger.info("Closing log process stdout.") + log_process.stdout.close() + + # Ensure the process is terminated + log_process.terminate() + try: + log_process.wait(timeout=5) + except subprocess.TimeoutExpired: + _logger.warning("Log streaming process did not terminate gracefully.") + log_process.kill() + _logger.info("Finished waiting for placement.") + + def _restore_env_var(key: str, original_value: str | None) -> None: """Restores an environment variable to its original value or unsets it.""" if original_value is None: @@ -147,6 +183,7 @@ class _ISCPathways: expected_tpu_instances: A dictionary mapping TPU machine types to the number of instances. proxy_job_name: The name to use for the deployed proxy. + proxy_pod_name: The name of the proxy pod, assigned during deployment. proxy_server_image: The image to use for the proxy server. proxy_options: Configuration options for the Pathways proxy. """ @@ -171,6 +208,7 @@ def __init__( self.pathways_service = pathways_service self.expected_tpu_instances = expected_tpu_instances self._proxy_job_name = proxy_job_name + self.proxy_pod_name = "" self._port_forward_process = None self._proxy_port = None self.proxy_server_image = proxy_server_image @@ -218,9 +256,11 @@ def __enter__(self): ) _logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link) - proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name) + self.proxy_pod_name = gke_utils.wait_for_pod(self._proxy_job_name) self._proxy_port, self._port_forward_process = ( - gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT) + gke_utils.enable_port_forwarding( + self.proxy_pod_name, PROXY_SERVER_PORT + ) ) # Update the JAX backend to use the proxy. @@ -349,4 +389,16 @@ def connect( proxy_server_image=proxy_server_image, proxy_options=proxy_options, ) as t: + if t.proxy_pod_name: + placement_thread = threading.Thread( + target=_wait_for_placement, + args=(t.proxy_pod_name,), + daemon=True, + ) + placement_thread.start() + else: + _logger.warning( + "proxy_pod_name not set on _ISCPathways instance, skipping background" + " _wait_for_placement." + ) yield t