From f283322367495aec4eddf41d53f381ceb497053f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 5 Dec 2025 00:57:15 +0000 Subject: [PATCH 01/48] Add model manager and rename modelmanager in base --- sdks/python/apache_beam/ml/inference/base.py | 24 +- .../apache_beam/ml/inference/base_test.py | 9 +- .../apache_beam/ml/inference/model_manager.py | 381 ++++++++++++++++++ .../ml/inference/model_manager_test.py | 279 +++++++++++++ 4 files changed, 677 insertions(+), 16 deletions(-) create mode 100644 sdks/python/apache_beam/ml/inference/model_manager.py create mode 100644 sdks/python/apache_beam/ml/inference/model_manager_test.py diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 2e1c4963f11d..37f0c901923b 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -470,11 +470,12 @@ def request( raise NotImplementedError(type(self)) -class _ModelManager: +class _ModelHandlerManager: """ - A class for efficiently managing copies of multiple models. Will load a - single copy of each model into a multi_process_shared object and then - return a lookup key for that object. + A class for efficiently managing copies of multiple model handlers. + Will load a single copy of each model from the model handler into a + multi_process_shared object and then return a lookup key for that + object. Used for KeyedModelHandler only. """ def __init__(self, mh_map: dict[str, ModelHandler]): """ @@ -539,8 +540,9 @@ def load(self, key: str) -> _ModelLoadStats: def increment_max_models(self, increment: int): """ - Increments the number of models that this instance of a _ModelManager is - able to hold. If it is never called, no limit is imposed. + Increments the number of models that this instance of a + _ModelHandlerManager is able to hold. If it is never called, + no limit is imposed. Args: increment: the amount by which we are incrementing the number of models. """ @@ -593,7 +595,7 @@ def __init__( class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[tuple[KeyT, ExampleT], tuple[KeyT, PredictionT], - Union[ModelT, _ModelManager]]): + Union[ModelT, _ModelHandlerManager]]): def __init__( self, unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], @@ -746,15 +748,15 @@ def __init__( 'to exactly one model handler.') self._key_to_id_map[key] = keys[0] - def load_model(self) -> Union[ModelT, _ModelManager]: + def load_model(self) -> Union[ModelT, _ModelHandlerManager]: if self._single_model: return self._unkeyed.load_model() - return _ModelManager(self._id_to_mh_map) + return _ModelHandlerManager(self._id_to_mh_map) def run_inference( self, batch: Sequence[tuple[KeyT, ExampleT]], - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], inference_args: Optional[dict[str, Any]] = None ) -> Iterable[tuple[KeyT, PredictionT]]: if self._single_model: @@ -856,7 +858,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): def update_model_paths( self, - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], model_paths: list[KeyModelPathMapping[KeyT]] = None): # When there are many models, the keyed model handler is responsible for # reorganizing the model handlers into cohorts and telling the model diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 66e85ce163e7..dd3704ff7f03 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1593,7 +1593,7 @@ def test_child_class_without_env_vars(self): actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars()) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_model_manager_loads_shared_model(self): + def test_model_handler_manager_loads_shared_model(self): mhs = { 'key1': FakeModelHandler(state=1), 'key2': FakeModelHandler(state=2), @@ -1617,7 +1617,7 @@ def test_model_manager_loads_shared_model(self): self.assertEqual(2, model2.predict(10)) self.assertEqual(3, model3.predict(10)) - def test_model_manager_evicts_models(self): + def test_model_handler_manager_evicts_models(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) @@ -1661,7 +1661,7 @@ def test_model_manager_evicts_models(self): mh3.load_model, tag=tag3).acquire() self.assertEqual(8, model3.predict(10)) - def test_model_manager_evicts_models_after_update(self): + def test_model_handler_manager_evicts_models_after_update(self): mh1 = FakeModelHandler(state=1) mhs = {'key1': mh1} mm = base._ModelManager(mh_map=mhs) @@ -1691,8 +1691,7 @@ def test_model_manager_evicts_models_after_update(self): self.assertEqual(6, model1.predict(10)) sh1.release(model1) - def test_model_manager_evicts_correct_num_of_models_after_being_incremented( - self): + def test_model_handler_manager_evicts_correctly_after_being_incremented(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py new file mode 100644 index 000000000000..0c3caea12d0e --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -0,0 +1,381 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# TODO: https://github.com/apache/beam/issues/21822 +# mypy: ignore-errors + +"""Module for managing ML models in Apache Beam pipelines. + +This module provides classes and functions to efficiently manage multiple +machine learning models within Apache Beam pipelines. It includes functionality +for loading, caching, and updating models using multi-process shared memory, +ensuring that models are reused across different workers to optimize resource +usage and performance. +""" + +import time +import threading +import subprocess +import logging +import gc +import numpy as np +from scipy.optimize import nnls +import torch +from collections import defaultdict, deque, Counter +from contextlib import contextmanager +from typing import Dict, Any, Tuple, Optional, Callable + +_NANOSECOND_TO_MILLISECOND = 1_000_000 +_NANOSECOND_TO_MICROSECOND = 1_000 +_MILLISECOND_TO_SECOND = 1_000 + +ModelT = TypeVar("ModelT") +ExampleT = TypeVar("ExampleT") +PreProcessT = TypeVar("PreProcessT") +PredictionT = TypeVar("PredictionT") +PostProcessT = TypeVar("PostProcessT") +_INPUT_TYPE = TypeVar("_INPUT_TYPE") +_OUTPUT_TYPE = TypeVar("_OUTPUT_TYPE") +KeyT = TypeVar("KeyT") + +# Configure Logging +logger = logging.getLogger(__name__) + +# Constants +SLACK_PERCENTAGE = 0.15 +POLL_INTERVAL = 0.5 +PEAK_WINDOW_SECONDS = 30.0 +SMOOTHING_FACTOR = 0.2 + + +@contextmanager +def cuda_oom_guard(description: str): + """Safely catches OOM, clears cache, and re-raises.""" + try: + yield + except torch.cuda.OutOfMemoryError as e: + logger.error("CUDA OOM DETECTED during: %s", description) + torch.cuda.empty_cache() + gc.collect() + raise e + + +class GPUMonitor: + def __init__(self, fallback_memory_mb: float = 16000.0): + self._current_usage = 0.0 + self._peak_usage = 0.0 + self._total_memory = fallback_memory_mb + self._memory_history = deque() + self._running = False + self._thread = None + self._lock = threading.Lock() + self._detect_hardware() + + def _detect_hardware(self): + try: + cmd = "nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits" + output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + self._total_memory = float(output) + except Exception: + logger.warning( + "nvidia-smi failed. Defaulting total memory to %s MB", + self._total_memory) + + def start(self): + if self._running: + return + self._running = True + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self): + self._running = False + if self._thread: + self._thread.join() + + def reset_peak(self): + with self._lock: + now = time.time() + self._memory_history.clear() + self._memory_history.append((now, self._current_usage)) + self._peak_usage = self._current_usage + + def get_stats(self) -> Tuple[float, float, float]: + with self._lock: + return self._current_usage, self._peak_usage, self._total_memory + + def _get_nvidia_smi_used(self) -> float: + try: + cmd = "nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits" + output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + return float(output) + except Exception: + return 0.0 + + def _poll_loop(self): + while self._running: + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + while self._memory_history and (now - self._memory_history[0][0] + > PEAK_WINDOW_SECONDS): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + time.sleep(POLL_INTERVAL) + + +class ResourceEstimator: + def __init__(self): + self.estimates: Dict[str, float] = {} + self.history = defaultdict(lambda: deque(maxlen=20)) + self.known_models = set() + self._lock = threading.Lock() + + def is_unknown(self, model_tag: str) -> bool: + with self._lock: + return model_tag not in self.estimates + + def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float: + with self._lock: + return self.estimates.get(model_tag, default_mb) + + def set_initial_estimate(self, model_tag: str, cost: float): + with self._lock: + self.estimates[model_tag] = cost + self.known_models.add(model_tag) + logger.info("Initial Profile for %s: %s MB", model_tag, cost) + + def add_observation( + self, active_snapshot: Dict[str, int], peak_memory: float): + if not active_snapshot: + return + with self._lock: + config_key = tuple(sorted(active_snapshot.items())) + self.history[config_key].append(peak_memory) + for tag in active_snapshot: + self.known_models.add(tag) + self._solve() + + def _solve(self): + """ + Solves Ax=b using raw readings (no pre-averaging) and NNLS. + This creates a 'tall' matrix A where every memory reading is + a separate equation. + """ + unique = sorted(list(self.known_models)) + + # We need to build the matrix first to know if we have enough data points + A, b = [], [] + + for config_key, mem_values in self.history.items(): + if not mem_values: + continue + + # 1. Create the feature row for this configuration ONCE + # (It represents the model counts + bias) + counts = dict(config_key) + feature_row = [counts.get(model, 0) for model in unique] + feature_row.append(1) # Bias column + + # 2. Add a separate row to the matrix for EVERY individual reading + # Instead of averaging, we flatten the history into the matrix + for reading in mem_values: + A.append(feature_row) # The inputs (models) stay the same + b.append(reading) # The output (memory) varies due to noise + + # Convert to numpy for SciPy + A = np.array(A) + b = np.array(b) + + if len(A) < len(unique) + 1: + # Not enough data to solve yet + return + + print(f"Solving with {len(A)} total observations for {len(unique)} models.") + + try: + # Solve using Non-Negative Least Squares + # x will be >= 0 + x, _ = nnls(A, b) + + weights = x[:-1] + bias = x[-1] + + for i, model in enumerate(unique): + calculated_cost = weights[i] + print(f"Solved Cost for {model}: {calculated_cost:.1f} MB") + + if model in self.estimates: + old = self.estimates[model] + new = (old * (1 - SMOOTHING_FACTOR)) + ( + calculated_cost * SMOOTHING_FACTOR) + self.estimates[model] = new + else: + self.estimates[model] = calculated_cost + + print(f"System Bias: {bias:.1f} MB") + + except Exception as e: + logger.error("Solver failed: %s", e) + + +class ModelManager: + _lock = threading.Lock() + + def __init__(self, monitor: Optional[GPUMonitor] = None): + self.estimator = ResourceEstimator() + self.monitor = monitor if monitor else GPUMonitor() + + self.idle_pool = defaultdict(list) + self.active_counts = Counter() + self.total_active_jobs = 0 + self.pending_reservations = 0.0 + + # State Control + self.isolation_mode = False + self.pending_isolation_count = 0 + self.isolation_baseline = 0.0 + + self._cv = threading.Condition() + self.monitor.start() + + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: + should_spawn = False + est_cost = 0.0 + is_unknown = False + + with self._cv: + while True: + is_unknown = self.estimator.is_unknown(tag) + + # Path A: Isolation for Unknown Models + if is_unknown: + self.pending_isolation_count += 1 + try: + while self.total_active_jobs > 0 or self.isolation_mode: + self._cv.wait() + if not self.estimator.is_unknown(tag): + is_unknown = False + break + + if not is_unknown: + continue + + self.isolation_mode = True + self.total_active_jobs += 1 + self.isolation_baseline, _, _ = self.monitor.get_stats() + self.monitor.reset_peak() + should_spawn = True + break + finally: + self.pending_isolation_count -= 1 + if not should_spawn: + self._cv.notify_all() + + # Path B: Concurrent Execution + else: + # Writer Priority (allow unknown models to drain system) + if self.pending_isolation_count > 0 or self.isolation_mode: + self._cv.wait() + continue + + if self.idle_pool[tag]: + instance = self.idle_pool[tag].pop() + self.active_counts[tag] += 1 + self.total_active_jobs += 1 + return instance + + # Capacity Check + curr, peak, total = self.monitor.get_stats() + est_cost = self.estimator.get_estimate(tag) + limit = total * (1 - SLACK_PERCENTAGE) + base_usage = max(curr, peak) + + if (base_usage + self.pending_reservations + est_cost) <= limit: + self.pending_reservations += est_cost + self.total_active_jobs += 1 + self.active_counts[tag] += 1 + should_spawn = True + break + + self._cv.wait() + + # Execution Logic (Spawn) + if should_spawn: + try: + isolation_baseline_snap, _, _ = self.monitor.get_stats() + with cuda_oom_guard(f"Loading {tag}"): + instance = loader_func() + + _, peak_during_load, _ = self.monitor.get_stats() + snapshot = {tag: 1} + self.estimator.add_observation( + snapshot, peak_during_load - isolation_baseline_snap) + + if not is_unknown: + self.pending_reservations = max( + 0.0, self.pending_reservations - est_cost) + return instance + + except Exception as e: + self.total_active_jobs -= 1 + if is_unknown: + self.isolation_mode = False + self.isolation_baseline = 0.0 + else: + self.pending_reservations = max( + 0.0, self.pending_reservations - est_cost) + self.active_counts[tag] -= 1 + self._cv.notify_all() + raise e + + def release_model(self, tag: str, instance: Any): + with self._cv: + try: + self.total_active_jobs -= 1 + if self.active_counts[tag] > 0: + self.active_counts[tag] -= 1 + + # Return to pool + self.idle_pool[tag].append(instance) + + _, peak_during_job, _ = self.monitor.get_stats() + + if self.isolation_mode and self.active_counts[tag] == 0: + cost = max(0, peak_during_job - self.isolation_baseline) + self.estimator.set_initial_estimate(tag, cost) + self.isolation_mode = False + self.isolation_baseline = 0.0 + else: + # Solver Snapshot + snapshot = dict(self.active_counts) + for pool_tag, models in self.idle_pool.items(): + snapshot[pool_tag] = snapshot.get(pool_tag, 0) + len(models) + + if snapshot: + print( + f"Release Snapshot: {snapshot}, Peak: {peak_during_job:.1f} MB") + self.estimator.add_observation(snapshot, peak_during_job) + + finally: + self._cv.notify_all() + + def shutdown(self): + self.monitor.stop() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py new file mode 100644 index 000000000000..7e8a3c80201e --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -0,0 +1,279 @@ +import unittest +import time +import threading +import random +from concurrent.futures import ThreadPoolExecutor + +# Import from the library file +from apache_beam.ml.inference.model_manager import ModelManager + + +class MockGPUMonitor: + """ + Simulates GPU hardware with cumulative memory tracking. + Allows simulating specific allocation spikes and baseline usage. + """ + def __init__(self, total_memory=12000.0, peak_window: int = 5): + self._current = 0.0 + self._peak = 0.0 + self._total = total_memory + self._lock = threading.Lock() + self.running = False + self.history = [] + self.peak_window = peak_window + + def start(self): + self.running = True + + def stop(self): + self.running = False + + def get_stats(self): + with self._lock: + return self._current, self._peak, self._total + + def reset_peak(self): + with self._lock: + self._peak = self._current + + # --- Test Helper Methods --- + def set_usage(self, current_mb): + """Sets absolute usage (legacy helper).""" + with self._lock: + self._current = current_mb + self._peak = max(self._peak, current_mb) + + def allocate(self, amount_mb): + """Simulates memory allocation (e.g., tensors loaded to VRAM).""" + with self._lock: + self._current += amount_mb + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def free(self, amount_mb): + """Simulates memory freeing (not used often if pooling is active).""" + with self._lock: + self._current = max(0.0, self._current - amount_mb) + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + +class TestModelManager(unittest.TestCase): + def setUp(self): + """Force reset the Singleton ModelManager before every test.""" + # 1. Reset the Singleton instance + ModelManager._instance = None + + # 2. Instantiate Mock Monitor directly + self.mock_monitor = MockGPUMonitor() + + # 3. Inject Mock Monitor into Manager + self.manager = ModelManager(monitor=self.mock_monitor) + + def tearDown(self): + self.manager.shutdown() + + def test_model_manager_capacity_check(self): + """ + Test that the manager blocks when spawning models exceeds the limit, + and unblocks when resources become available (via reuse). + """ + model_name = "known_model" + model_cost = 3000.0 + # Total Memory: 12000. Limit (15% slack) ~ 10200. + # 3 * 3000 = 9000 (OK). + # 4 * 3000 = 12000 (Over Limit). + + self.manager.estimator.set_initial_estimate(model_name, model_cost) + + acquired_refs = [] + + def loader(): + self.mock_monitor.allocate(model_cost) + return model_name + + # 1. Saturate GPU with 3 models (9000 MB usage) + for _ in range(3): + inst = self.manager.acquire_model(model_name, loader) + acquired_refs.append(inst) + + # 2. Spawn one more (Should Block because 9000 + 3000 > Limit) + def run_inference(): + return self.manager.acquire_model(model_name, loader) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_inference) + + # Verify it blocks + try: + future.result(timeout=0.5) + self.fail("Should have blocked due to capacity") + except TimeoutError: + pass + + # 3. Release resources to unblock + # Releasing one puts it in the idle pool. + # The blocked thread should wake up, see the idle one in the pool, + # and reuse it. + item_to_release = acquired_refs.pop() + self.manager.release_model(model_name, item_to_release) + + # 4. Verify Success + # The previous logic required a manual notify loop because set_usage + # didn't notify. release_model calls notify_all(), so standard futures + # waiting works here. + result = future.result(timeout=2.0) + self.assertIsNotNone(result) + + # Verify we reused the released instance (optimization check) + self.assertEqual(result, item_to_release) + + def test_model_manager_unknown_model_runs_isolated(self): + """Test that a model with no history runs in isolation.""" + model_name = "unknown_model_v1" + self.assertTrue(self.manager.estimator.is_unknown(model_name)) + + def dummy_loader(): + time.sleep(0.05) + return "model_instance" + + instance = self.manager.acquire_model(model_name, dummy_loader) + + self.assertTrue(self.manager.isolation_mode) + self.assertEqual(self.manager.total_active_jobs, 1) + + self.manager.release_model(model_name, instance) + self.assertFalse(self.manager.isolation_mode) + self.assertFalse(self.manager.estimator.is_unknown(model_name)) + + def test_model_manager_concurrent_execution(self): + """Test that multiple small known models can run together.""" + model_a = "small_model_a" + model_b = "small_model_b" + + self.manager.estimator.set_initial_estimate(model_a, 1000.0) + self.manager.estimator.set_initial_estimate(model_b, 1000.0) + self.mock_monitor.set_usage(1000.0) + + inst_a = self.manager.acquire_model(model_a, lambda: "A") + inst_b = self.manager.acquire_model(model_b, lambda: "B") + + self.assertEqual(self.manager.total_active_jobs, 2) + + self.manager.release_model(model_a, inst_a) + self.manager.release_model(model_b, inst_b) + self.assertEqual(self.manager.total_active_jobs, 0) + + def test_model_manager_concurrent_mixed_workload_convergence(self): + """ + Simulates a production environment with multiple model types running + concurrently. Verifies that the estimator converges. + """ + # --- Configuration --- + TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} + + def run_job(model_name): + cost = TRUE_COSTS[model_name] + + # Loader: Simulates the initial memory spike when loading to VRAM + def loader(): + self.mock_monitor.allocate(cost) + time.sleep(0.01) + return f"instance_{model_name}" + + # 1. Acquire + # Note: If reused, loader isn't called, so memory stays stable. + # If new, loader runs and bumps monitor memory. + instance = self.manager.acquire_model(model_name, loader) + + # 2. Simulate Inference Work + # In a real GPU, inference might spike memory further (activations). + # For this test, we assume the 'cost' captures the peak usage. + time.sleep(random.uniform(0.01, 0.05)) + + # 3. Release + self.manager.release_model(model_name, instance) + + # Create a workload stream + # 15 Small jobs, 15 Medium jobs, mixed order + workload = ["model_small"] * 15 + ["model_medium"] * 15 + random.shuffle(workload) + + # We use a thread pool slightly larger than the theoretical capacity + # to force queuing and reuse logic. + # Capacity ~12000. Small=1500, Med=3000. + # Max concurrent approx: 4 Med (12000) or 8 Small (12000). + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_job, name) for name in workload] + for f in futures: + f.result() + + # --- Assertions --- + est_small = self.manager.estimator.get_estimate("model_small") + est_med = self.manager.estimator.get_estimate("model_medium") + + # Check convergence (allow some margin for solver approximation) + self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) + self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) + + def test_model_manager_oom_recovery(self): + """Test that the manager recovers state if a loader crashes.""" + model_name = "crasher_model" + self.manager.estimator.set_initial_estimate(model_name, 1000.0) + + def crashing_loader(): + raise RuntimeError("CUDA OOM or similar") + + with self.assertRaises(RuntimeError): + self.manager.acquire_model(model_name, crashing_loader) + + self.assertEqual(self.manager.total_active_jobs, 0) + self.assertEqual(self.manager.pending_reservations, 0.0) + self.assertFalse(self.manager._cv._is_owned()) + + def test_single_model_convergence_with_fluctuations(self): + """ + Tests that the estimator converges to the true usage with: + 1. A single model type. + 2. Initial 'Load' cost that is lower than 'Inference' cost. + 3. High variance/fluctuation during inference. + """ + model_name = "fluctuating_model" + model_cost = 3000.0 + load_cost = 2000.0 # Initial load cost underestimates true cost + + def loader(): + self.mock_monitor.allocate(load_cost) + return model_name + + # Check that initial estimate is only the load cost + model = self.manager.acquire_model(model_name, loader) + self.manager.release_model(model_name, model) + initial_est = self.manager.estimator.get_estimate(model_name) + self.assertEqual(initial_est, load_cost) + + def run_inference(): + model = self.manager.acquire_model(model_name, loader) + noise = model_cost - load_cost + random.uniform(-300.0, 300.0) + self.mock_monitor.allocate(noise) + time.sleep(0.1) + self.mock_monitor.free(noise) + self.manager.release_model(model_name, model) + return + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_inference) for _ in range(100)] + + for f in futures: + f.result() + + est_cost = self.manager.estimator.get_estimate(model_name) + self.assertAlmostEqual(est_cost, model_cost, delta=100.0) + + +if __name__ == "__main__": + unittest.main() From efb3f4d2b2b00536d52b91b060f8fa373b5878d0 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 5 Dec 2025 00:58:34 +0000 Subject: [PATCH 02/48] Update indent --- .../apache_beam/ml/inference/model_manager.py | 8 ++++---- .../ml/inference/model_manager_test.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 0c3caea12d0e..ffc393600b1f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -175,10 +175,10 @@ def add_observation( def _solve(self): """ - Solves Ax=b using raw readings (no pre-averaging) and NNLS. - This creates a 'tall' matrix A where every memory reading is - a separate equation. - """ + Solves Ax=b using raw readings (no pre-averaging) and NNLS. + This creates a 'tall' matrix A where every memory reading is + a separate equation. + """ unique = sorted(list(self.known_models)) # We need to build the matrix first to know if we have enough data points diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 7e8a3c80201e..744811ec7977 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -10,9 +10,9 @@ class MockGPUMonitor: """ - Simulates GPU hardware with cumulative memory tracking. - Allows simulating specific allocation spikes and baseline usage. - """ + Simulates GPU hardware with cumulative memory tracking. + Allows simulating specific allocation spikes and baseline usage. + """ def __init__(self, total_memory=12000.0, peak_window: int = 5): self._current = 0.0 self._peak = 0.0 @@ -79,9 +79,9 @@ def tearDown(self): def test_model_manager_capacity_check(self): """ - Test that the manager blocks when spawning models exceeds the limit, - and unblocks when resources become available (via reuse). - """ + Test that the manager blocks when spawning models exceeds the limit, + and unblocks when resources become available (via reuse). + """ model_name = "known_model" model_cost = 3000.0 # Total Memory: 12000. Limit (15% slack) ~ 10200. @@ -170,9 +170,9 @@ def test_model_manager_concurrent_execution(self): def test_model_manager_concurrent_mixed_workload_convergence(self): """ - Simulates a production environment with multiple model types running - concurrently. Verifies that the estimator converges. - """ + Simulates a production environment with multiple model types running + concurrently. Verifies that the estimator converges. + """ # --- Configuration --- TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} From d896b9da0b5838cd239601ab8b4c2eba6e64ada3 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 5 Dec 2025 06:15:41 +0000 Subject: [PATCH 03/48] RunInference with model manager --- sdks/python/apache_beam/ml/inference/base.py | 33 ++++++++++++++++--- .../apache_beam/ml/inference/base_test.py | 8 ++--- .../apache_beam/ml/inference/model_manager.py | 18 +++------- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 37f0c901923b..399a3b4adc1f 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -59,6 +59,7 @@ from apache_beam.utils import multi_process_shared from apache_beam.utils import retry from apache_beam.utils import shared +from apache_beam.ml.inference.model_manager import ModelManager try: # pylint: disable=wrong-import-order, wrong-import-position @@ -1748,15 +1749,30 @@ class _SharedModelWrapper(): This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ - def __init__(self, models: list[Any], model_tag: str): + def __init__( + self, + models: Union[list[Any], ModelManager], + model_tag: str, + loader_func: Callable[[], Any] = None): self.models = models - if len(models) > 1: + self.use_model_manager = isinstance(models, ModelManager) + self.model_tag = model_tag + self.loader_func = loader_func + if not self.use_model_manager and len(models) > 1: self.model_router = multi_process_shared.MultiProcessShared( lambda: _ModelRoutingStrategy(), tag=f'{model_tag}_counter', always_proxy=True).acquire() def next_model(self): + if self.use_model_manager: + + def load(): + unique_tag = self.model_tag + '_' + uuid.uuid4().hex + return multi_process_shared.MultiProcessShared( + self.loader_func, tag=unique_tag, always_proxy=True).acquire() + + return self.models.acquire_model(self.model_tag, load) if len(self.models) == 1: # Short circuit if there's no routing strategy needed in order to # avoid the cross-process call @@ -1765,6 +1781,8 @@ def next_model(self): return self.models[self.model_router.next_model_index(len(self.models))] def all_models(self): + if self.use_model_manager: + return self.models.all_models()[self.model_tag] return self.models @@ -1775,7 +1793,8 @@ def __init__( clock, metrics_namespace, load_model_at_runtime: bool = False, - model_tag: str = "RunInference"): + model_tag: str = "RunInference", + use_model_manager: bool = True): """A DoFn implementation generic to frameworks. Args: @@ -1799,6 +1818,7 @@ def __init__( # _cur_tag is the tag of the actually loaded model self._model_tag = model_tag self._cur_tag = model_tag + self.use_model_manager = use_model_manager def _load_model( self, @@ -1833,7 +1853,12 @@ def load(): model_tag = side_input_model_path # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) - if self._model_handler.share_model_across_processes(): + if self.use_model_manager: + model_manager = multi_process_shared.MultiProcessShared( + lambda: ModelManager(), tag='model_manager', + always_proxy=True).acquire() + model_wrapper = _SharedModelWrapper(model_manager, self._cur_tag, load) + elif self._model_handler.share_model_across_processes(): models = [] for copy_tag in _get_tags_for_copies(self._cur_tag, self._model_handler.model_copies()): diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index dd3704ff7f03..33409542925f 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1599,7 +1599,7 @@ def test_model_handler_manager_loads_shared_model(self): 'key2': FakeModelHandler(state=2), 'key3': FakeModelHandler(state=3) } - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag # Use bad_mh's load function to make sure we're actually loading the # version already stored @@ -1622,7 +1622,7 @@ def test_model_handler_manager_evicts_models(self): mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(2) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) @@ -1664,7 +1664,7 @@ def test_model_handler_manager_evicts_models(self): def test_model_handler_manager_evicts_models_after_update(self): mh1 = FakeModelHandler(state=1) mhs = {'key1': mh1} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() @@ -1696,7 +1696,7 @@ def test_model_handler_manager_evicts_correctly_after_being_incremented(self): mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(1) mm.increment_max_models(1) tag1 = mm.load('key1').model_tag diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index ffc393600b1f..045ec66fd51c 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -38,19 +38,6 @@ from contextlib import contextmanager from typing import Dict, Any, Tuple, Optional, Callable -_NANOSECOND_TO_MILLISECOND = 1_000_000 -_NANOSECOND_TO_MICROSECOND = 1_000 -_MILLISECOND_TO_SECOND = 1_000 - -ModelT = TypeVar("ModelT") -ExampleT = TypeVar("ExampleT") -PreProcessT = TypeVar("PreProcessT") -PredictionT = TypeVar("PredictionT") -PostProcessT = TypeVar("PostProcessT") -_INPUT_TYPE = TypeVar("_INPUT_TYPE") -_OUTPUT_TYPE = TypeVar("_OUTPUT_TYPE") -KeyT = TypeVar("KeyT") - # Configure Logging logger = logging.getLogger(__name__) @@ -243,6 +230,7 @@ def __init__(self, monitor: Optional[GPUMonitor] = None): self.estimator = ResourceEstimator() self.monitor = monitor if monitor else GPUMonitor() + self.models = defaultdict(list) self.idle_pool = defaultdict(list) self.active_counts = Counter() self.total_active_jobs = 0 @@ -256,6 +244,9 @@ def __init__(self, monitor: Optional[GPUMonitor] = None): self._cv = threading.Condition() self.monitor.start() + def all_models(self, tag) -> list[Any]: + return self.models[tag] + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: should_spawn = False est_cost = 0.0 @@ -332,6 +323,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if not is_unknown: self.pending_reservations = max( 0.0, self.pending_reservations - est_cost) + self.models[tag].append(instance) return instance except Exception as e: From 4c0a93357dc467305410c8f08f9acdf13cb9308d Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 5 Dec 2025 06:37:38 +0000 Subject: [PATCH 04/48] fix --- sdks/python/apache_beam/ml/inference/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 399a3b4adc1f..81635b77c5b5 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1755,7 +1755,7 @@ def __init__( model_tag: str, loader_func: Callable[[], Any] = None): self.models = models - self.use_model_manager = isinstance(models, ModelManager) + self.use_model_manager = not isinstance(models, list) self.model_tag = model_tag self.loader_func = loader_func if not self.use_model_manager and len(models) > 1: From a6ba692687992670a777727f5abe01617a0a0ade Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 5 Dec 2025 06:56:32 +0000 Subject: [PATCH 05/48] fix pickle --- sdks/python/apache_beam/ml/inference/base.py | 58 ++++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 81635b77c5b5..667cf3a41456 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1764,15 +1764,14 @@ def __init__( tag=f'{model_tag}_counter', always_proxy=True).acquire() + def _load_model(self): + unique_tag = self.model_tag + '_' + uuid.uuid4().hex + return multi_process_shared.MultiProcessShared( + self.loader_func, tag=unique_tag, always_proxy=True).acquire() + def next_model(self): if self.use_model_manager: - - def load(): - unique_tag = self.model_tag + '_' + uuid.uuid4().hex - return multi_process_shared.MultiProcessShared( - self.loader_func, tag=unique_tag, always_proxy=True).acquire() - - return self.models.acquire_model(self.model_tag, load) + return self.models.acquire_model(self.model_tag, self._load_model) if len(self.models) == 1: # Short circuit if there's no routing strategy needed in order to # avoid the cross-process call @@ -1820,32 +1819,32 @@ def __init__( self._cur_tag = model_tag self.use_model_manager = use_model_manager + def _load(self): + """Function for constructing shared LoadedModel.""" + memory_before = _get_current_process_memory_in_bytes() + start_time = _to_milliseconds(self._clock.time_ns()) + if isinstance(side_input_model_path, str): + self._model_handler.update_model_path(side_input_model_path) + else: + if self._model is not None: + models = self._model.all_models() + for m in models: + self._model_handler.update_model_paths(m, side_input_model_path) + model = self._model_handler.load_model() + end_time = _to_milliseconds(self._clock.time_ns()) + memory_after = _get_current_process_memory_in_bytes() + load_model_latency_ms = end_time - start_time + model_byte_size = memory_after - memory_before + if self._metrics_collector: + self._metrics_collector.cache_load_model_metrics( + load_model_latency_ms, model_byte_size) + return model + def _load_model( self, side_input_model_path: Optional[Union[str, list[KeyModelPathMapping]]] = None ) -> _SharedModelWrapper: - def load(): - """Function for constructing shared LoadedModel.""" - memory_before = _get_current_process_memory_in_bytes() - start_time = _to_milliseconds(self._clock.time_ns()) - if isinstance(side_input_model_path, str): - self._model_handler.update_model_path(side_input_model_path) - else: - if self._model is not None: - models = self._model.all_models() - for m in models: - self._model_handler.update_model_paths(m, side_input_model_path) - model = self._model_handler.load_model() - end_time = _to_milliseconds(self._clock.time_ns()) - memory_after = _get_current_process_memory_in_bytes() - load_model_latency_ms = end_time - start_time - model_byte_size = memory_after - memory_before - if self._metrics_collector: - self._metrics_collector.cache_load_model_metrics( - load_model_latency_ms, model_byte_size) - return model - # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing # model. model_tag = self._model_tag @@ -1857,7 +1856,8 @@ def load(): model_manager = multi_process_shared.MultiProcessShared( lambda: ModelManager(), tag='model_manager', always_proxy=True).acquire() - model_wrapper = _SharedModelWrapper(model_manager, self._cur_tag, load) + model_wrapper = _SharedModelWrapper( + model_manager, self._cur_tag, self._load) elif self._model_handler.share_model_across_processes(): models = [] for copy_tag in _get_tags_for_copies(self._cur_tag, From a273f6e10a34cf00422418db48c7471a90e90ee9 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 6 Dec 2025 03:05:18 +0000 Subject: [PATCH 06/48] Fix pickling and auto proxy --- sdks/python/apache_beam/ml/inference/base.py | 84 ++++++++++++++----- .../apache_beam/utils/multi_process_shared.py | 6 ++ 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 667cf3a41456..71b356870776 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -60,6 +60,7 @@ from apache_beam.utils import retry from apache_beam.utils import shared from apache_beam.ml.inference.model_manager import ModelManager +from apache_beam.ml.inference.model_manager import ModelManager try: # pylint: disable=wrong-import-order, wrong-import-position @@ -1743,12 +1744,31 @@ def load_model_status( return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag) +class _ProxyLoader: + """ + A helper callable to wrap the loader for MultiProcessShared. + """ + def __init__(self, loader_func, model_tag): + self.loader_func = loader_func + self.model_tag = model_tag + + def __call__(self): + unique_tag = self.model_tag + '_' + uuid.uuid4().hex + return multi_process_shared.MultiProcessShared( + self.loader_func, tag=unique_tag, always_proxy=True).acquire() + + class _SharedModelWrapper(): """A router class to map incoming calls to the correct model. This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ + def __init__( + self, + models: Union[list[Any], ModelManager], + model_tag: str, + loader_func: Callable[[], Any] = None): def __init__( self, models: Union[list[Any], ModelManager], @@ -1758,6 +1778,10 @@ def __init__( self.use_model_manager = not isinstance(models, list) self.model_tag = model_tag self.loader_func = loader_func + if not self.use_model_manager and len(models) > 1: + self.use_model_manager = not isinstance(models, list) + self.model_tag = model_tag + self.loader_func = loader_func if not self.use_model_manager and len(models) > 1: self.model_router = multi_process_shared.MultiProcessShared( lambda: _ModelRoutingStrategy(), @@ -1771,7 +1795,8 @@ def _load_model(self): def next_model(self): if self.use_model_manager: - return self.models.acquire_model(self.model_tag, self._load_model) + loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag) + return self.models.acquire_model(self.model_tag, loader_wrapper) if len(self.models) == 1: # Short circuit if there's no routing strategy needed in order to # avoid the cross-process call @@ -1779,7 +1804,13 @@ def next_model(self): return self.models[self.model_router.next_model_index(len(self.models))] + def release_model(self, model_tag: str, model: Any): + if self.use_model_manager: + self.models.release_model(model_tag, model) + def all_models(self): + if self.use_model_manager: + return self.models.all_models()[self.model_tag] if self.use_model_manager: return self.models.all_models()[self.model_tag] return self.models @@ -1794,6 +1825,8 @@ def __init__( load_model_at_runtime: bool = False, model_tag: str = "RunInference", use_model_manager: bool = True): + model_tag: str = "RunInference", + use_model_manager: bool = True): """A DoFn implementation generic to frameworks. Args: @@ -1819,26 +1852,31 @@ def __init__( self._cur_tag = model_tag self.use_model_manager = use_model_manager - def _load(self): - """Function for constructing shared LoadedModel.""" - memory_before = _get_current_process_memory_in_bytes() - start_time = _to_milliseconds(self._clock.time_ns()) - if isinstance(side_input_model_path, str): - self._model_handler.update_model_path(side_input_model_path) - else: - if self._model is not None: - models = self._model.all_models() - for m in models: - self._model_handler.update_model_paths(m, side_input_model_path) - model = self._model_handler.load_model() - end_time = _to_milliseconds(self._clock.time_ns()) - memory_after = _get_current_process_memory_in_bytes() - load_model_latency_ms = end_time - start_time - model_byte_size = memory_after - memory_before - if self._metrics_collector: - self._metrics_collector.cache_load_model_metrics( - load_model_latency_ms, model_byte_size) - return model + def _load_model( + self, + side_input_model_path: Optional[Union[str, + list[KeyModelPathMapping]]] = None + ) -> _SharedModelWrapper: + def load(): + """Function for constructing shared LoadedModel.""" + memory_before = _get_current_process_memory_in_bytes() + start_time = _to_milliseconds(self._clock.time_ns()) + if isinstance(side_input_model_path, str): + self._model_handler.update_model_path(side_input_model_path) + else: + if self._model is not None: + models = self._model.all_models() + for m in models: + self._model_handler.update_model_paths(m, side_input_model_path) + model = self._model_handler.load_model() + end_time = _to_milliseconds(self._clock.time_ns()) + memory_after = _get_current_process_memory_in_bytes() + load_model_latency_ms = end_time - start_time + model_byte_size = memory_after - memory_before + if self._metrics_collector: + self._metrics_collector.cache_load_model_metrics( + load_model_latency_ms, model_byte_size) + return model def _load_model( self, @@ -1857,7 +1895,7 @@ def _load_model( lambda: ModelManager(), tag='model_manager', always_proxy=True).acquire() model_wrapper = _SharedModelWrapper( - model_manager, self._cur_tag, self._load) + model_manager, self._cur_tag, self._model_handler.load_model) elif self._model_handler.share_model_across_processes(): models = [] for copy_tag in _get_tags_for_copies(self._cur_tag, @@ -1915,6 +1953,8 @@ def _run_inference(self, batch, inference_args): model = self._model.next_model() result_generator = self._model_handler.run_inference( batch, model, inference_args) + if self.use_model_manager: + self._model.release_model(self._model_tag, model) except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index aecb1284a1d4..0b082ede205b 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -200,6 +200,12 @@ def __call__(self, *args, **kwargs): def __getattr__(self, name): return getattr(self._proxyObject, name) + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + def get_auto_proxy_object(self): return self._proxyObject From 3f6a7b95433540a111a50f1f49955825804e6a1d Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 6 Dec 2025 03:26:46 +0000 Subject: [PATCH 07/48] fix --- sdks/python/apache_beam/ml/inference/base.py | 24 -------------------- 1 file changed, 24 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 71b356870776..940449dd4708 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -60,7 +60,6 @@ from apache_beam.utils import retry from apache_beam.utils import shared from apache_beam.ml.inference.model_manager import ModelManager -from apache_beam.ml.inference.model_manager import ModelManager try: # pylint: disable=wrong-import-order, wrong-import-position @@ -1764,11 +1763,6 @@ class _SharedModelWrapper(): This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ - def __init__( - self, - models: Union[list[Any], ModelManager], - model_tag: str, - loader_func: Callable[[], Any] = None): def __init__( self, models: Union[list[Any], ModelManager], @@ -1778,21 +1772,12 @@ def __init__( self.use_model_manager = not isinstance(models, list) self.model_tag = model_tag self.loader_func = loader_func - if not self.use_model_manager and len(models) > 1: - self.use_model_manager = not isinstance(models, list) - self.model_tag = model_tag - self.loader_func = loader_func if not self.use_model_manager and len(models) > 1: self.model_router = multi_process_shared.MultiProcessShared( lambda: _ModelRoutingStrategy(), tag=f'{model_tag}_counter', always_proxy=True).acquire() - def _load_model(self): - unique_tag = self.model_tag + '_' + uuid.uuid4().hex - return multi_process_shared.MultiProcessShared( - self.loader_func, tag=unique_tag, always_proxy=True).acquire() - def next_model(self): if self.use_model_manager: loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag) @@ -1809,8 +1794,6 @@ def release_model(self, model_tag: str, model: Any): self.models.release_model(model_tag, model) def all_models(self): - if self.use_model_manager: - return self.models.all_models()[self.model_tag] if self.use_model_manager: return self.models.all_models()[self.model_tag] return self.models @@ -1825,8 +1808,6 @@ def __init__( load_model_at_runtime: bool = False, model_tag: str = "RunInference", use_model_manager: bool = True): - model_tag: str = "RunInference", - use_model_manager: bool = True): """A DoFn implementation generic to frameworks. Args: @@ -1878,11 +1859,6 @@ def load(): load_model_latency_ms, model_byte_size) return model - def _load_model( - self, - side_input_model_path: Optional[Union[str, - list[KeyModelPathMapping]]] = None - ) -> _SharedModelWrapper: # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing # model. model_tag = self._model_tag From 379080a4faba71863fd9db99784921cc4ba90faf Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 8 Dec 2025 18:22:07 +0000 Subject: [PATCH 08/48] Add more tests --- sdks/python/apache_beam/ml/inference/base.py | 7 +- .../apache_beam/ml/inference/base_test.py | 28 ++++++++ .../ml/inference/model_manager_it_test.py | 65 +++++++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 sdks/python/apache_beam/ml/inference/model_manager_it_test.py diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 940449dd4708..f0e13fea265c 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1278,6 +1278,7 @@ def __init__( model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, watch_model_pattern: Optional[str] = None, model_identifier: Optional[str] = None, + use_model_manager: bool = False, **kwargs): """ A transform that takes a PCollection of examples (or features) for use @@ -1318,6 +1319,7 @@ def __init__( self._exception_handling_timeout = None self._timeout = None self._watch_model_pattern = watch_model_pattern + self._use_model_manager = use_model_manager self._kwargs = kwargs # Generate a random tag to use for shared.py and multi_process_shared.py to # allow us to effectively disambiguate in multi-model settings. Only use @@ -1430,7 +1432,8 @@ def expand( self._clock, self._metrics_namespace, load_model_at_runtime, - self._model_tag), + self._model_tag, + self._use_model_manager), self._inference_args, beam.pvalue.AsSingleton( self._model_metadata_pcoll, @@ -1807,7 +1810,7 @@ def __init__( metrics_namespace, load_model_at_runtime: bool = False, model_tag: str = "RunInference", - use_model_manager: bool = True): + use_model_manager: bool = False): """A DoFn implementation generic to frameworks. Args: diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 33409542925f..16cda2b5e41e 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1881,6 +1881,34 @@ def test_model_status_provides_valid_garbage_collection(self): self.assertEqual(0, len(tags)) + def test_run_inference_impl_with_model_manager(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + expected[0] = 200 + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + FakeModelHandler(state=200), use_model_manager=True) + assert_that(actual, equal_to(expected), label='assert:inferences') + + def test_run_inference_impl_with_model_manager_keyed_handler(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + keyed_examples = [(i, example) for i, example in enumerate(examples)] + expected = [(i, example + 1) for i, example in enumerate(examples)] + expected[0] = (0, 200) + pcoll = pipeline | 'start' >> beam.Create(keyed_examples) + mhs = [ + base.KeyModelMapping([0], + FakeModelHandler( + state=200, multi_process_shared=True)), + base.KeyModelMapping([1, 2, 3], + FakeModelHandler(multi_process_shared=True)) + ] + actual = pcoll | base.RunInference( + base.KeyedModelHandler(mhs), use_model_manager=True) + assert_that(actual, equal_to(expected), label='assert:inferences') + def _always_retry(e: Exception) -> bool: return True diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py new file mode 100644 index 000000000000..b6e8977de644 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -0,0 +1,65 @@ +import unittest +import torch +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +try: + from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler +except ImportError as e: + raise unittest.SkipTest( + "HuggingFace model handler dependencies are not installed") +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to + + +class HuggingFaceGpuTest(unittest.TestCase): + + # This decorator skips the test if you run it on a machine without a GPU + @unittest.skipIf( + not torch.cuda.is_available(), "No GPU detected, skipping GPU test") + def test_sentiment_analysis_on_gpu_large_input(self): + """ + Runs inference on a GPU (device=0) with a larger set of inputs. + """ + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="distilbert-base-uncased-finetuned-sst-2-english", + device=0, # <--- This forces GPU usage + inference_args={"batch_size": 4 + } # Optional: Control batch size sent to GPU + ) + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "The weather is okay, but I wish it were sunnier.", + "Apache Beam makes parallel processing incredibly efficient.", + "I am extremely disappointed with the service.", + "Logic and reason are the pillars of good debugging.", + "I'm so happy today!", + "This error message is confusing and unhelpful.", + "The movie was fantastic and the acting was superb.", + "I hate waiting in line for so long." + ] + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + + predictions = pcoll | 'RunInference' >> RunInference(model_handler) + + actual_labels = predictions | beam.Map(lambda x: x.inference[0]['label']) + + expected_labels = [ + 'POSITIVE', # "love this product" + 'NEGATIVE', # "worst experience" + 'NEGATIVE', # "weather is okay, but..." + 'POSITIVE', # "incredibly efficient" + 'NEGATIVE', # "disappointed" + 'POSITIVE', # "pillars of good debugging" + 'POSITIVE', # "so happy" + 'NEGATIVE', # "confusing and unhelpful" + 'POSITIVE', # "fantastic" + 'NEGATIVE' # "hate waiting" + ] + + assert_that( + actual_labels, equal_to(expected_labels), label='CheckPredictions') From fac671ca34b9aa45f38ddaac8a1641b882afe57e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 8 Dec 2025 19:05:00 +0000 Subject: [PATCH 09/48] Fix test --- sdks/python/apache_beam/ml/inference/model_manager_it_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index b6e8977de644..4fd319ccebbd 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -46,7 +46,7 @@ def test_sentiment_analysis_on_gpu_large_input(self): predictions = pcoll | 'RunInference' >> RunInference(model_handler) - actual_labels = predictions | beam.Map(lambda x: x.inference[0]['label']) + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) expected_labels = [ 'POSITIVE', # "love this product" From a2b79018c60b64a8da84236239aff0d0a0c0cd70 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 8 Dec 2025 19:26:59 +0000 Subject: [PATCH 10/48] Add more test --- .../ml/inference/model_manager_it_test.py | 59 ++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 4fd319ccebbd..dc73ff19d184 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -27,6 +27,7 @@ def test_sentiment_analysis_on_gpu_large_input(self): inference_args={"batch_size": 4 } # Optional: Control batch size sent to GPU ) + DUPLICATE_FACTOR = 2 # Increase to test larger inputs with TestPipeline() as pipeline: examples = [ @@ -40,7 +41,7 @@ def test_sentiment_analysis_on_gpu_large_input(self): "This error message is confusing and unhelpful.", "The movie was fantastic and the acting was superb.", "I hate waiting in line for so long." - ] + ] * DUPLICATE_FACTOR pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) @@ -59,7 +60,61 @@ def test_sentiment_analysis_on_gpu_large_input(self): 'NEGATIVE', # "confusing and unhelpful" 'POSITIVE', # "fantastic" 'NEGATIVE' # "hate waiting" - ] + ] * DUPLICATE_FACTOR assert_that( actual_labels, equal_to(expected_labels), label='CheckPredictions') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_sentiment_analysis_large_roberta_gpu(self): + """ + Runs inference using a Large architecture (RoBERTa-Large, ~355M params). + This tests if the GPU can handle larger weights and requires more VRAM. + """ + + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="Siebert/sentiment-roberta-large-english", + device=0, + inference_args={"batch_size": 2}) + + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "Apache Beam scales effortlessly to massive datasets.", + "I am somewhat annoyed by the delay.", + "The nuanced performance of this large model is impressive.", + "I regret buying this immediately.", + "The sunset looks beautiful tonight.", + "This documentation is sparse and misleading.", + "Winning the championship felt surreal.", + "I'm feeling very neutral about this whole situation." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + predictions = pcoll | 'RunInference' >> RunInference(model_handler) + actual_labels = predictions | beam.Map( + lambda x: x.inference[0]['label']) + + # Note: Larger models are often more accurate with nuance. + # e.g. "somewhat annoyed" is confidently NEGATIVE. + expected_labels = [ + 'POSITIVE', # love + 'NEGATIVE', # worst + 'POSITIVE', # scales effortlessly + 'NEGATIVE', # annoyed + 'POSITIVE', # impressive + 'NEGATIVE', # regret + 'POSITIVE', # beautiful + 'NEGATIVE', # misleading + 'POSITIVE', # surreal + 'NEGATIVE' # "neutral" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, + equal_to(expected_labels), + label='CheckPredictionsLarge') From 019ec731a8c095039ec0780dceb626f9c1b3dd0f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 8 Dec 2025 19:29:16 +0000 Subject: [PATCH 11/48] Add more test --- sdks/python/apache_beam/ml/inference/model_manager_it_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index dc73ff19d184..fea397c209ec 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -96,8 +96,7 @@ def test_sentiment_analysis_large_roberta_gpu(self): pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) predictions = pcoll | 'RunInference' >> RunInference(model_handler) - actual_labels = predictions | beam.Map( - lambda x: x.inference[0]['label']) + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) # Note: Larger models are often more accurate with nuance. # e.g. "somewhat annoyed" is confidently NEGATIVE. From e979c67703fdde234a13430432cbdad68ee43cc6 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 8 Dec 2025 19:36:13 +0000 Subject: [PATCH 12/48] Add more test --- .../ml/inference/model_manager_it_test.py | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index fea397c209ec..732a790e14a4 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -65,55 +65,55 @@ def test_sentiment_analysis_on_gpu_large_input(self): assert_that( actual_labels, equal_to(expected_labels), label='CheckPredictions') - @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") - def test_sentiment_analysis_large_roberta_gpu(self): - """ - Runs inference using a Large architecture (RoBERTa-Large, ~355M params). - This tests if the GPU can handle larger weights and requires more VRAM. - """ - - model_handler = HuggingFacePipelineModelHandler( - task="sentiment-analysis", - model="Siebert/sentiment-roberta-large-english", - device=0, - inference_args={"batch_size": 2}) - - DUPLICATE_FACTOR = 2 - - with TestPipeline() as pipeline: - examples = [ - "I absolutely love this product, it's a game changer!", - "This is the worst experience I have ever had.", - "Apache Beam scales effortlessly to massive datasets.", - "I am somewhat annoyed by the delay.", - "The nuanced performance of this large model is impressive.", - "I regret buying this immediately.", - "The sunset looks beautiful tonight.", - "This documentation is sparse and misleading.", - "Winning the championship felt surreal.", - "I'm feeling very neutral about this whole situation." - ] * DUPLICATE_FACTOR - - pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) - predictions = pcoll | 'RunInference' >> RunInference(model_handler) - actual_labels = predictions | beam.Map(lambda x: x.inference['label']) - - # Note: Larger models are often more accurate with nuance. - # e.g. "somewhat annoyed" is confidently NEGATIVE. - expected_labels = [ - 'POSITIVE', # love - 'NEGATIVE', # worst - 'POSITIVE', # scales effortlessly - 'NEGATIVE', # annoyed - 'POSITIVE', # impressive - 'NEGATIVE', # regret - 'POSITIVE', # beautiful - 'NEGATIVE', # misleading - 'POSITIVE', # surreal - 'NEGATIVE' # "neutral" - ] * DUPLICATE_FACTOR - - assert_that( - actual_labels, - equal_to(expected_labels), - label='CheckPredictionsLarge') + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_sentiment_analysis_large_roberta_gpu(self): + """ + Runs inference using a Large architecture (RoBERTa-Large, ~355M params). + This tests if the GPU can handle larger weights and requires more VRAM. + """ + + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="Siebert/sentiment-roberta-large-english", + device=0, + inference_args={"batch_size": 2}) + + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "Apache Beam scales effortlessly to massive datasets.", + "I am somewhat annoyed by the delay.", + "The nuanced performance of this large model is impressive.", + "I regret buying this immediately.", + "The sunset looks beautiful tonight.", + "This documentation is sparse and misleading.", + "Winning the championship felt surreal.", + "I'm feeling very neutral about this whole situation." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + predictions = pcoll | 'RunInference' >> RunInference(model_handler) + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) + + # Note: Larger models are often more accurate with nuance. + # e.g. "somewhat annoyed" is confidently NEGATIVE. + expected_labels = [ + 'POSITIVE', # love + 'NEGATIVE', # worst + 'POSITIVE', # scales effortlessly + 'NEGATIVE', # annoyed + 'POSITIVE', # impressive + 'NEGATIVE', # regret + 'POSITIVE', # beautiful + 'NEGATIVE', # misleading + 'POSITIVE', # surreal + 'NEGATIVE' # "neutral" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, + equal_to(expected_labels), + label='CheckPredictionsLarge') From 7df7d53c1057892ad0992a9856910e533ed090ee Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 9 Dec 2025 04:44:58 +0000 Subject: [PATCH 13/48] Add more test and error handling --- sdks/python/apache_beam/ml/inference/base.py | 2 + .../apache_beam/ml/inference/base_test.py | 31 ++++++- .../apache_beam/ml/inference/model_manager.py | 56 ++++++++++-- .../ml/inference/model_manager_test.py | 90 ++++++++++++++++++- 4 files changed, 169 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index f0e13fea265c..88b8be9e620e 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1937,6 +1937,8 @@ def _run_inference(self, batch, inference_args): except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() + if self.use_model_manager: + self._model.force_reset() if (e is pickle.PickleError and self._model_handler.share_model_across_processes()): raise TypeError( diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 16cda2b5e41e..86b800c68a4a 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -51,6 +51,16 @@ def predict(self, example: int) -> int: return example + 1 +class FakeFailsOnceModel: + _has_failed = False + + def predict(self, example: int) -> int: + if not FakeFailsOnceModel._has_failed: + FakeFailsOnceModel._has_failed = True + raise Exception('Intentional Failure') + return example + + class FakeStatefulModel: def __init__(self, state: int): if state == 100: @@ -128,6 +138,7 @@ def __init__( incrementing=False, max_copies=1, num_bytes_per_element=None, + inference_fail_once=False, **kwargs): self._fake_clock = clock self._min_batch_size = min_batch_size @@ -138,11 +149,14 @@ def __init__( self._incrementing = incrementing self._max_copies = max_copies self._num_bytes_per_element = num_bytes_per_element + self._inference_fail_once = inference_fail_once def load_model(self): assert (not self._incrementing or self._state is None) if self._fake_clock: self._fake_clock.current_time_ns += 500_000_000 # 500ms + if self._inference_fail_once: + return FakeFailsOnceModel() if self._incrementing: return FakeIncrementingModel() if self._state is not None: @@ -1885,12 +1899,25 @@ def test_run_inference_impl_with_model_manager(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] expected = [example + 1 for example in examples] - expected[0] = 200 pcoll = pipeline | 'start' >> beam.Create(examples) actual = pcoll | base.RunInference( - FakeModelHandler(state=200), use_model_manager=True) + FakeModelHandler(multi_process_shared=True), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') + def test_run_inference_impl_with_model_manager_fail_and_retry(self): + pipeline = TestPipeline() + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + with self.assertRaises(Exception): + actual = ( + pipeline | 'start' >> beam.Create(examples) + | base.RunInference( + FakeModelHandler( + multi_process_shared=True, inference_fail_once=True), + use_model_manager=True)) + pipeline.run() + assert_that(actual, equal_to(expected), label='assert:inferences') + def test_run_inference_impl_with_model_manager_keyed_handler(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 045ec66fd51c..dc497acf57d2 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -55,8 +55,8 @@ def cuda_oom_guard(description: str): yield except torch.cuda.OutOfMemoryError as e: logger.error("CUDA OOM DETECTED during: %s", description) - torch.cuda.empty_cache() gc.collect() + torch.cuda.empty_cache() raise e @@ -69,20 +69,33 @@ def __init__(self, fallback_memory_mb: float = 16000.0): self._running = False self._thread = None self._lock = threading.Lock() - self._detect_hardware() + self._gpu_available = self._detect_hardware() def _detect_hardware(self): try: - cmd = "nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits" - output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + cmd = [ + "nvidia-smi", + "--query-gpu=memory.total", + "--format=csv,noheader,nounits" + ] + output = subprocess.check_output(cmd, text=True).strip() self._total_memory = float(output) - except Exception: + return True + except (FileNotFoundError, subprocess.CalledProcessError): + logger.warning( + "nvidia-smi not found or failed. Defaulting total memory to %s MB", + self._total_memory) + return False + except Exception as e: logger.warning( - "nvidia-smi failed. Defaulting total memory to %s MB", + "Error parsing nvidia-smi output: %s. " + "Defaulting total memory to %s MB", + e, self._total_memory) + return False def start(self): - if self._running: + if self._running or not self._gpu_available: return self._running = True self._thread = threading.Thread(target=self._poll_loop, daemon=True) @@ -311,6 +324,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Execution Logic (Spawn) if should_spawn: try: + logger.info("Loading model for tag: %s", tag) isolation_baseline_snap, _, _ = self.monitor.get_stats() with cuda_oom_guard(f"Loading {tag}"): instance = loader_func() @@ -369,5 +383,33 @@ def release_model(self, tag: str, instance: Any): finally: self._cv.notify_all() + def force_reset(self): + for _, instances in self.models.items(): + del instances[:] + gc.collect() + torch.cuda.empty_cache() + + self.models = defaultdict(list) + self.idle_pool = defaultdict(list) + self.active_counts = Counter() + self.total_active_jobs = 0 + self.pending_reservations = 0.0 + self.isolation_mode = False + self.pending_isolation_count = 0 + self.isolation_baseline = 0.0 + def shutdown(self): + try: + for _, instances in self.models.items(): + del instances[:] + gc.collect() + torch.cuda.empty_cache() + except Exception as e: + logger.error("Error during ModelManager shutdown: %s", e) self.monitor.stop() + + def __del__(self): + self.shutdown() + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 744811ec7977..301e5551b31c 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -3,9 +3,10 @@ import threading import random from concurrent.futures import ThreadPoolExecutor +from unittest.mock import patch # Import from the library file -from apache_beam.ml.inference.model_manager import ModelManager +from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor class MockGPUMonitor: @@ -235,6 +236,29 @@ def crashing_loader(): self.assertEqual(self.manager.pending_reservations, 0.0) self.assertFalse(self.manager._cv._is_owned()) + def test_model_managaer_force_reset_on_exception(self): + """Test that force_reset clears all models from the manager.""" + model_name = "test_model" + + def dummy_loader(): + self.mock_monitor.allocate(1000.0) + raise RuntimeError("Simulated loader exception") + + # Acquire a model to populate the pool + try: + instance = self.manager.acquire_model(model_name, dummy_loader) + except RuntimeError: + self.manager.force_reset() + self.assertTrue(len(self.manager.models[model_name]) == 0) + self.assertEqual(self.manager.total_active_jobs, 0) + self.assertEqual(self.manager.pending_reservations, 0.0) + self.assertFalse(self.manager.isolation_mode) + pass + + # Verify we can still try again + instance = self.manager.acquire_model(model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + def test_single_model_convergence_with_fluctuations(self): """ Tests that the estimator converges to the true usage with: @@ -275,5 +299,69 @@ def run_inference(): self.assertAlmostEqual(est_cost, model_cost, delta=100.0) +class TestGPUMonitor(unittest.TestCase): + def setUp(self): + self.subprocess_patcher = patch('subprocess.check_output') + self.mock_subprocess = self.subprocess_patcher.start() + + def tearDown(self): + self.subprocess_patcher.stop() + + def test_init_hardware_detected(self): + """Test that init correctly reads total memory when nvidia-smi exists.""" + self.mock_subprocess.return_value = "24576" + monitor = GPUMonitor() + self.assertTrue(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 24576.0) + + def test_init_hardware_missing(self): + """Test fallback behavior when nvidia-smi is missing.""" + self.mock_subprocess.side_effect = FileNotFoundError() + monitor = GPUMonitor(fallback_memory_mb=12000.0) + self.assertFalse(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 12000.0) + + @patch('time.sleep') # Patch sleep to speed up the test + def test_polling_updates_stats(self, mock_sleep): + """Test that the polling loop updates current and peak usage.""" + def subprocess_side_effect(*args, **kwargs): + if isinstance(args[0], list) and "memory.total" in args[0][1]: + return "16000" + + if isinstance(args[0], str) and "memory.used" in args[0]: + return b"4000" + + raise Exception("Unexpected command") + + self.mock_subprocess.side_effect = subprocess_side_effect + self.mock_subprocess.return_value = None # Clear default + + monitor = GPUMonitor() + monitor.start() + time.sleep(0.1) + curr, peak, total = monitor.get_stats() + monitor.stop() + + self.assertEqual(curr, 4000.0) + self.assertEqual(peak, 4000.0) + self.assertEqual(total, 16000.0) + + def test_reset_peak(self): + """Test that resetting peak usage works.""" + monitor = GPUMonitor() + monitor._gpu_available = True + + with monitor._lock: + monitor._current_usage = 2000.0 + monitor._peak_usage = 8000.0 + monitor._memory_history.append((time.time(), 8000.0)) + monitor._memory_history.append((time.time(), 2000.0)) + + monitor.reset_peak() + + _, peak, _ = monitor.get_stats() + self.assertEqual(peak, 2000.0) # Peak should reset to current + + if __name__ == "__main__": unittest.main() From eb21943e33490fc4a227e263a0fc110c52f01f7e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 9 Dec 2025 04:49:45 +0000 Subject: [PATCH 14/48] Add logging --- .../apache_beam/ml/inference/model_manager.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index dc497acf57d2..053a7f2a126c 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -261,6 +261,26 @@ def all_models(self, tag) -> list[Any]: return self.models[tag] def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: + logger.info( + "Acquiring model for tag: %s | " + "idle_pool size: %d | " + "active_count: %d | " + "total_active_jobs: %d | " + "pending_reservations: %.1f | " + "isolation_mode: %s | " + "pending_isolation_count: %d | " + "estimator known: %s | " + "estimator cost: %.1f MB", + tag, + len(self.idle_pool[tag]), + self.active_counts[tag], + self.total_active_jobs, + self.pending_reservations, + self.isolation_mode, + self.pending_isolation_count, + not self.estimator.is_unknown(tag), + self.estimator.get_estimate(tag), + ) should_spawn = False est_cost = 0.0 is_unknown = False From 4b653232b7a78824eebf0d1b369fc7500af6e5ab Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 9 Dec 2025 05:39:35 +0000 Subject: [PATCH 15/48] Add logging --- sdks/python/apache_beam/ml/inference/base.py | 4 ++++ sdks/python/apache_beam/ml/inference/model_manager.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 88b8be9e620e..b2367c338cce 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1801,6 +1801,10 @@ def all_models(self): return self.models.all_models()[self.model_tag] return self.models + def force_reset(self): + if self.use_model_manager: + self.models.force_reset() + class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): def __init__( diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 053a7f2a126c..67b078399be3 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -348,7 +348,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: isolation_baseline_snap, _, _ = self.monitor.get_stats() with cuda_oom_guard(f"Loading {tag}"): instance = loader_func() - + logger.info("Model loaded for tag: %s", tag) _, peak_during_load, _ = self.monitor.get_stats() snapshot = {tag: 1} self.estimator.add_observation( From e2e96bd8b5b91735e0b4ad33218ff5f34696eaa2 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 9 Dec 2025 21:38:47 +0000 Subject: [PATCH 16/48] Fix memory check --- .../apache_beam/ml/inference/model_manager.py | 32 +++++++++++++++---- .../ml/inference/model_manager_test.py | 4 +-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 67b078399be3..a69bf50daea1 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -119,9 +119,10 @@ def get_stats(self) -> Tuple[float, float, float]: def _get_nvidia_smi_used(self) -> float: try: - cmd = "nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits" + cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() - return float(output) + free_memory = float(output) + return self._total_memory - free_memory except Exception: return 0.0 @@ -208,7 +209,10 @@ def _solve(self): # Not enough data to solve yet return - print(f"Solving with {len(A)} total observations for {len(unique)} models.") + logger.info( + "Solving with %s total observations for %s models.", + len(A), + len(unique)) try: # Solve using Non-Negative Least Squares @@ -220,7 +224,7 @@ def _solve(self): for i, model in enumerate(unique): calculated_cost = weights[i] - print(f"Solved Cost for {model}: {calculated_cost:.1f} MB") + logger.info("Solved Cost for %s: %s MB", model, calculated_cost) if model in self.estimates: old = self.estimates[model] @@ -230,7 +234,7 @@ def _solve(self): else: self.estimates[model] = calculated_cost - print(f"System Bias: {bias:.1f} MB") + logger.info("System Bias: %s MB", bias) except Exception as e: logger.error("Solver failed: %s", e) @@ -339,6 +343,20 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: should_spawn = True break + logger.info( + "Model load blocked for tag: %s | " + "Current Usage: %.1f MB | " + "Peak Usage: %.1f MB | " + "Total Memory: %.1f MB | " + "Estimated Cost: %.1f MB | " + "Limit: %.1f MB", + tag, + curr, + peak, + total, + est_cost, + limit, + ) self._cv.wait() # Execution Logic (Spawn) @@ -396,8 +414,8 @@ def release_model(self, tag: str, instance: Any): snapshot[pool_tag] = snapshot.get(pool_tag, 0) + len(models) if snapshot: - print( - f"Release Snapshot: {snapshot}, Peak: {peak_during_job:.1f} MB") + logger.info( + "Release Snapshot: %s, Peak: %s MB", snapshot, peak_during_job) self.estimator.add_observation(snapshot, peak_during_job) finally: diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 301e5551b31c..592c8763c370 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -328,8 +328,8 @@ def subprocess_side_effect(*args, **kwargs): if isinstance(args[0], list) and "memory.total" in args[0][1]: return "16000" - if isinstance(args[0], str) and "memory.used" in args[0]: - return b"4000" + if isinstance(args[0], str) and "memory.free" in args[0]: + return b"12000" raise Exception("Unexpected command") From d3537ee17d35c5668e07272b379ebc82afd5b169 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 10 Dec 2025 07:06:15 +0000 Subject: [PATCH 17/48] Fix solver check --- .../apache_beam/ml/inference/model_manager.py | 3 +- .../ml/inference/model_manager_test.py | 74 ++++++++++++++++++- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index a69bf50daea1..156d112754d8 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -46,6 +46,7 @@ POLL_INTERVAL = 0.5 PEAK_WINDOW_SECONDS = 30.0 SMOOTHING_FACTOR = 0.2 +MIN_DATA_POINTS = 5 @contextmanager @@ -205,7 +206,7 @@ def _solve(self): A = np.array(A) b = np.array(b) - if len(A) < len(unique) + 1: + if len(self.history.keys()) < len(unique) + 1 or len(A) < MIN_DATA_POINTS: # Not enough data to solve yet return diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 592c8763c370..3f760b97697d 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -6,7 +6,7 @@ from unittest.mock import patch # Import from the library file -from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor +from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator class MockGPUMonitor: @@ -363,5 +363,77 @@ def test_reset_peak(self): self.assertEqual(peak, 2000.0) # Peak should reset to current +class TestResourceEstimatorSolver(unittest.TestCase): + def setUp(self): + self.estimator = ResourceEstimator() + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_min_data_points(self, mock_nnls): + """ + Test that the solver does not run until the total number of observations + reaches MIN_DATA_POINTS, even if we have enough unique configurations. + """ + mock_nnls.return_value = ([100.0, 50.0], 0.0) + + self.estimator.add_observation({'model_A': 1}, 500) + self.estimator.add_observation({'model_B': 1}, 500) + self.assertFalse( + mock_nnls.called, + "Should not solve: Not enough data points or unique keys") + + # Now we have 3 unique keys. For 2 models, 3 >= 2+1. + # The 'Unique Models' constraint is now SATISFIED. + # However, total observations is 3. MIN_DATA_POINTS is 5. + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000) + + self.assertFalse( + mock_nnls.called, + "Should not solve yet: Total obs (3) < MIN_DATA_POINTS") + + self.estimator.add_observation({'model_A': 1}, 500) + self.assertFalse(mock_nnls.called) + + # Now Total Obs = 5. MIN_DATA_POINTS satisfied. + # Unique Keys = 3. Variety satisfied. + self.estimator.add_observation({'model_B': 1}, 500) + + self.assertTrue( + mock_nnls.called, + "Solver SHOULD run now that min data points are reached") + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_unique_model_constraint(self, mock_nnls): + """ + Test that the solver does not run if we have a lot of data points, + but they don't represent enough unique configurations to solve + the linear system safely. + """ + mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0) + + for _ in range(5): + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 800) + + for _ in range(5): + self.estimator.add_observation({'model_C': 1}, 400) + + # Current State: + # Total Observations: 10 (>> MIN_DATA_POINTS) + # Unique Keys: 2 ({A,B} and {C}) + # Required Keys: 4 (A, B, C + Bias) + + self.assertFalse( + mock_nnls.called, + "Should not solve: 2 unique keys provided, " + "but need 4 to solve for 3 models + bias") + + # Now let's satisfy the constraint by adding distinct configurations + self.estimator.add_observation({'model_A': 1}, 300) # Key 3 + self.estimator.add_observation({'model_B': 1}, 300) # Key 4 + + self.assertTrue( + mock_nnls.called, + "Solver SHOULD run now that we have enough unique configurations") + + if __name__ == "__main__": unittest.main() From a2e11781b3ade87cce1cdc32180a9af5f9eeb12e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 10 Dec 2025 07:11:11 +0000 Subject: [PATCH 18/48] Update logging --- .../apache_beam/ml/inference/model_manager.py | 17 +-------- .../ml/inference/model_manager_it_test.py | 12 ++---- .../ml/inference/model_manager_test.py | 38 ------------------- 3 files changed, 6 insertions(+), 61 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 156d112754d8..82dedabbae24 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -225,7 +225,6 @@ def _solve(self): for i, model in enumerate(unique): calculated_cost = weights[i] - logger.info("Solved Cost for %s: %s MB", model, calculated_cost) if model in self.estimates: old = self.estimates[model] @@ -235,6 +234,8 @@ def _solve(self): else: self.estimates[model] = calculated_cost + logger.info( + "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) logger.info("System Bias: %s MB", bias) except Exception as e: @@ -344,20 +345,6 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: should_spawn = True break - logger.info( - "Model load blocked for tag: %s | " - "Current Usage: %.1f MB | " - "Peak Usage: %.1f MB | " - "Total Memory: %.1f MB | " - "Estimated Cost: %.1f MB | " - "Limit: %.1f MB", - tag, - curr, - peak, - total, - est_cost, - limit, - ) self._cv.wait() # Execution Logic (Spawn) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 732a790e14a4..e55c2ec41a16 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -13,7 +13,7 @@ class HuggingFaceGpuTest(unittest.TestCase): - # This decorator skips the test if you run it on a machine without a GPU + # Skips the test if you run it on a machine without a GPU @unittest.skipIf( not torch.cuda.is_available(), "No GPU detected, skipping GPU test") def test_sentiment_analysis_on_gpu_large_input(self): @@ -23,11 +23,9 @@ def test_sentiment_analysis_on_gpu_large_input(self): model_handler = HuggingFacePipelineModelHandler( task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", - device=0, # <--- This forces GPU usage - inference_args={"batch_size": 4 - } # Optional: Control batch size sent to GPU - ) - DUPLICATE_FACTOR = 2 # Increase to test larger inputs + device=0, + inference_args={"batch_size": 4}) + DUPLICATE_FACTOR = 2 with TestPipeline() as pipeline: examples = [ @@ -98,8 +96,6 @@ def test_sentiment_analysis_large_roberta_gpu(self): predictions = pcoll | 'RunInference' >> RunInference(model_handler) actual_labels = predictions | beam.Map(lambda x: x.inference['label']) - # Note: Larger models are often more accurate with nuance. - # e.g. "somewhat annoyed" is confidently NEGATIVE. expected_labels = [ 'POSITIVE', # love 'NEGATIVE', # worst diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 3f760b97697d..33bb8bac6699 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -5,7 +5,6 @@ from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch -# Import from the library file from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator @@ -37,7 +36,6 @@ def reset_peak(self): with self._lock: self._peak = self._current - # --- Test Helper Methods --- def set_usage(self, current_mb): """Sets absolute usage (legacy helper).""" with self._lock: @@ -66,13 +64,8 @@ def free(self, amount_mb): class TestModelManager(unittest.TestCase): def setUp(self): """Force reset the Singleton ModelManager before every test.""" - # 1. Reset the Singleton instance ModelManager._instance = None - - # 2. Instantiate Mock Monitor directly self.mock_monitor = MockGPUMonitor() - - # 3. Inject Mock Monitor into Manager self.manager = ModelManager(monitor=self.mock_monitor) def tearDown(self): @@ -85,12 +78,7 @@ def test_model_manager_capacity_check(self): """ model_name = "known_model" model_cost = 3000.0 - # Total Memory: 12000. Limit (15% slack) ~ 10200. - # 3 * 3000 = 9000 (OK). - # 4 * 3000 = 12000 (Over Limit). - self.manager.estimator.set_initial_estimate(model_name, model_cost) - acquired_refs = [] def loader(): @@ -108,8 +96,6 @@ def run_inference(): with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(run_inference) - - # Verify it blocks try: future.result(timeout=0.5) self.fail("Should have blocked due to capacity") @@ -117,20 +103,11 @@ def run_inference(): pass # 3. Release resources to unblock - # Releasing one puts it in the idle pool. - # The blocked thread should wake up, see the idle one in the pool, - # and reuse it. item_to_release = acquired_refs.pop() self.manager.release_model(model_name, item_to_release) - # 4. Verify Success - # The previous logic required a manual notify loop because set_usage - # didn't notify. release_model calls notify_all(), so standard futures - # waiting works here. result = future.result(timeout=2.0) self.assertIsNotNone(result) - - # Verify we reused the released instance (optimization check) self.assertEqual(result, item_to_release) def test_model_manager_unknown_model_runs_isolated(self): @@ -174,29 +151,18 @@ def test_model_manager_concurrent_mixed_workload_convergence(self): Simulates a production environment with multiple model types running concurrently. Verifies that the estimator converges. """ - # --- Configuration --- TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} def run_job(model_name): cost = TRUE_COSTS[model_name] - # Loader: Simulates the initial memory spike when loading to VRAM def loader(): self.mock_monitor.allocate(cost) time.sleep(0.01) return f"instance_{model_name}" - # 1. Acquire - # Note: If reused, loader isn't called, so memory stays stable. - # If new, loader runs and bumps monitor memory. instance = self.manager.acquire_model(model_name, loader) - - # 2. Simulate Inference Work - # In a real GPU, inference might spike memory further (activations). - # For this test, we assume the 'cost' captures the peak usage. time.sleep(random.uniform(0.01, 0.05)) - - # 3. Release self.manager.release_model(model_name, instance) # Create a workload stream @@ -213,11 +179,9 @@ def loader(): for f in futures: f.result() - # --- Assertions --- est_small = self.manager.estimator.get_estimate("model_small") est_med = self.manager.estimator.get_estimate("model_medium") - # Check convergence (allow some margin for solver approximation) self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) @@ -244,7 +208,6 @@ def dummy_loader(): self.mock_monitor.allocate(1000.0) raise RuntimeError("Simulated loader exception") - # Acquire a model to populate the pool try: instance = self.manager.acquire_model(model_name, dummy_loader) except RuntimeError: @@ -255,7 +218,6 @@ def dummy_loader(): self.assertFalse(self.manager.isolation_mode) pass - # Verify we can still try again instance = self.manager.acquire_model(model_name, lambda: "model_instance") self.manager.release_model(model_name, instance) From 6c208dedab54427be33e71b4485f5b2dba0a6639 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 10 Dec 2025 23:41:01 +0000 Subject: [PATCH 19/48] Fix force reset --- .../apache_beam/ml/inference/model_manager.py | 19 ++++++++++--------- .../ml/inference/model_manager_test.py | 3 +++ .../apache_beam/utils/multi_process_shared.py | 18 ++++++++++++++++-- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 82dedabbae24..278ec00c4059 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -409,12 +409,17 @@ def release_model(self, tag: str, instance: Any): finally: self._cv.notify_all() - def force_reset(self): + def delete_all_models(self): for _, instances in self.models.items(): - del instances[:] + for instance in instances: + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + del instance gc.collect() torch.cuda.empty_cache() + def force_reset(self): + self.delete_all_models() self.models = defaultdict(list) self.idle_pool = defaultdict(list) self.active_counts = Counter() @@ -425,13 +430,9 @@ def force_reset(self): self.isolation_baseline = 0.0 def shutdown(self): - try: - for _, instances in self.models.items(): - del instances[:] - gc.collect() - torch.cuda.empty_cache() - except Exception as e: - logger.error("Error during ModelManager shutdown: %s", e) + self.delete_all_models() + gc.collect() + torch.cuda.empty_cache() self.monitor.stop() def __del__(self): diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 33bb8bac6699..a1d9f7276523 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -209,6 +209,9 @@ def dummy_loader(): raise RuntimeError("Simulated loader exception") try: + instance = self.manager.acquire_model( + model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) instance = self.manager.acquire_model(model_name, dummy_loader) except RuntimeError: self.manager.force_reset() diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 0b082ede205b..5e7b40618c4d 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -191,7 +191,10 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager): # singletonProxy_call__ calls (which is a wrapper around the underlying # object's __call__ function) class _AutoProxyWrapper: - def __init__(self, proxyObject: multiprocessing.managers.BaseProxy): + def __init__( + self, + proxyObject: multiprocessing.managers.BaseProxy, + deleter: Optional[Callable[[], None]] = None): self._proxyObject = proxyObject def __call__(self, *args, **kwargs): @@ -209,6 +212,13 @@ def __getstate__(self): def get_auto_proxy_object(self): return self._proxyObject + def unsafe_hard_delete(self): + if self._deleter: + self._deleter() + else: + raise NotImplementedError( + "This proxy was not initialized with deletion capabilities.") + class MultiProcessShared(Generic[T]): """MultiProcessShared is used to share a single object across processes. @@ -307,7 +317,11 @@ def acquire(self): # Caveat: They must always agree, as they will be ignored if the object # is already constructed. singleton = self._get_manager().acquire_singleton(self._tag) - return _AutoProxyWrapper(singleton) + + def deleter(): + manager.unsafe_hard_delete_singleton(self._tag) + + return _AutoProxyWrapper(singleton, deleter=deleter) def release(self, obj): self._manager.release_singleton(self._tag, obj.get_auto_proxy_object()) From a2694a7147b275945d0aa57cf976823b7134ec5c Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 00:08:45 +0000 Subject: [PATCH 20/48] Allow passing in model manager args --- sdks/python/apache_beam/ml/inference/base.py | 12 +++-- .../apache_beam/ml/inference/base_test.py | 18 +++++++ .../apache_beam/ml/inference/model_manager.py | 48 ++++++++++++------- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index b2367c338cce..4b7761a69f2a 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1279,6 +1279,7 @@ def __init__( watch_model_pattern: Optional[str] = None, model_identifier: Optional[str] = None, use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None, **kwargs): """ A transform that takes a PCollection of examples (or features) for use @@ -1320,6 +1321,7 @@ def __init__( self._timeout = None self._watch_model_pattern = watch_model_pattern self._use_model_manager = use_model_manager + self._model_manager_args = model_manager_args self._kwargs = kwargs # Generate a random tag to use for shared.py and multi_process_shared.py to # allow us to effectively disambiguate in multi-model settings. Only use @@ -1433,7 +1435,8 @@ def expand( self._metrics_namespace, load_model_at_runtime, self._model_tag, - self._use_model_manager), + self._use_model_manager, + self._model_manager_args), self._inference_args, beam.pvalue.AsSingleton( self._model_metadata_pcoll, @@ -1814,7 +1817,8 @@ def __init__( metrics_namespace, load_model_at_runtime: bool = False, model_tag: str = "RunInference", - use_model_manager: bool = False): + use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None): """A DoFn implementation generic to frameworks. Args: @@ -1839,6 +1843,7 @@ def __init__( self._model_tag = model_tag self._cur_tag = model_tag self.use_model_manager = use_model_manager + self._model_manager_args = model_manager_args or {} def _load_model( self, @@ -1875,7 +1880,8 @@ def load(): self._cur_tag = self._model_metadata.get_valid_tag(model_tag) if self.use_model_manager: model_manager = multi_process_shared.MultiProcessShared( - lambda: ModelManager(), tag='model_manager', + lambda: ModelManager(**self._model_manager_args), + tag='model_manager', always_proxy=True).acquire() model_wrapper = _SharedModelWrapper( model_manager, self._cur_tag, self._model_handler.load_model) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 86b800c68a4a..051190be380c 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1904,6 +1904,24 @@ def test_run_inference_impl_with_model_manager(self): FakeModelHandler(multi_process_shared=True), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') + def test_run_inference_impl_with_model_manager_args(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + FakeModelHandler( + multi_process_shared=True, min_batch_size=2, max_batch_size=4), + use_model_manager=True, + model_manager_args={ + 'slack_percentage': 0.2, + 'poll_interval': 1.0, + 'peak_window_seconds': 10.0, + 'min_data_points': 10, + 'smoothing_factor': 0.5 + }) + assert_that(actual, equal_to(expected), label='assert:inferences') + def test_run_inference_impl_with_model_manager_fail_and_retry(self): pipeline = TestPipeline() examples = [1, 5, 3, 10] diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 278ec00c4059..f0b1d615bdb8 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -41,13 +41,6 @@ # Configure Logging logger = logging.getLogger(__name__) -# Constants -SLACK_PERCENTAGE = 0.15 -POLL_INTERVAL = 0.5 -PEAK_WINDOW_SECONDS = 30.0 -SMOOTHING_FACTOR = 0.2 -MIN_DATA_POINTS = 5 - @contextmanager def cuda_oom_guard(description: str): @@ -62,10 +55,16 @@ def cuda_oom_guard(description: str): class GPUMonitor: - def __init__(self, fallback_memory_mb: float = 16000.0): + def __init__( + self, + fallback_memory_mb: float = 16000.0, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0): self._current_usage = 0.0 self._peak_usage = 0.0 self._total_memory = fallback_memory_mb + self._poll_interval = poll_interval + self._peak_window_seconds = peak_window_seconds self._memory_history = deque() self._running = False self._thread = None @@ -135,16 +134,18 @@ def _poll_loop(self): self._current_usage = usage self._memory_history.append((now, usage)) while self._memory_history and (now - self._memory_history[0][0] - > PEAK_WINDOW_SECONDS): + > self._peak_window_seconds): self._memory_history.popleft() self._peak_usage = ( max(m for _, m in self._memory_history) if self._memory_history else usage) - time.sleep(POLL_INTERVAL) + time.sleep(self._poll_interval) class ResourceEstimator: - def __init__(self): + def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): + self.smoothing_factor = smoothing_factor + self.min_data_points = min_data_points self.estimates: Dict[str, float] = {} self.history = defaultdict(lambda: deque(maxlen=20)) self.known_models = set() @@ -206,7 +207,8 @@ def _solve(self): A = np.array(A) b = np.array(b) - if len(self.history.keys()) < len(unique) + 1 or len(A) < MIN_DATA_POINTS: + if len( + self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points: # Not enough data to solve yet return @@ -228,8 +230,8 @@ def _solve(self): if model in self.estimates: old = self.estimates[model] - new = (old * (1 - SMOOTHING_FACTOR)) + ( - calculated_cost * SMOOTHING_FACTOR) + new = (old * (1 - self.smoothing_factor)) + ( + calculated_cost * self.smoothing_factor) self.estimates[model] = new else: self.estimates[model] = calculated_cost @@ -245,9 +247,19 @@ def _solve(self): class ModelManager: _lock = threading.Lock() - def __init__(self, monitor: Optional[GPUMonitor] = None): - self.estimator = ResourceEstimator() - self.monitor = monitor if monitor else GPUMonitor() + def __init__( + self, + monitor: Optional[GPUMonitor] = None, + slack_percentage: float = 0.15, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0, + min_data_points: int = 5, + smoothing_factor: float = 0.2): + self.estimator = ResourceEstimator( + min_data_points=min_data_points, smoothing_factor=smoothing_factor) + self.monitor = monitor if monitor else GPUMonitor( + poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) + self.slack_percentage = slack_percentage self.models = defaultdict(list) self.idle_pool = defaultdict(list) @@ -335,7 +347,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Capacity Check curr, peak, total = self.monitor.get_stats() est_cost = self.estimator.get_estimate(tag) - limit = total * (1 - SLACK_PERCENTAGE) + limit = total * (1 - self.slack_percentage) base_usage = max(curr, peak) if (base_usage + self.pending_reservations + est_cost) <= limit: From d1e9a8fdffb36421daed5011030a5da92c26c1e7 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 02:22:29 +0000 Subject: [PATCH 21/48] Fix multiprocessingshared --- .../apache_beam/utils/multi_process_shared.py | 3 +- .../utils/multi_process_shared_test.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 5e7b40618c4d..87b364a61e7c 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -196,6 +196,7 @@ def __init__( proxyObject: multiprocessing.managers.BaseProxy, deleter: Optional[Callable[[], None]] = None): self._proxyObject = proxyObject + self._deleter = deleter def __call__(self, *args, **kwargs): return self._proxyObject.singletonProxy_call__(*args, **kwargs) @@ -319,7 +320,7 @@ def acquire(self): singleton = self._get_manager().acquire_singleton(self._tag) def deleter(): - manager.unsafe_hard_delete_singleton(self._tag) + self._get_manager().unsafe_hard_delete_singleton(self._tag) return _AutoProxyWrapper(singleton, deleter=deleter) diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 0b7957632368..68dd3d7221af 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -193,6 +193,34 @@ def test_unsafe_hard_delete(self): self.assertEqual(counter3.increment(), 1) + def test_unsafe_hard_delete_autoproxywrapper(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True) + shared2 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True) + + counter1 = shared1.acquire() + counter2 = shared2.acquire() + self.assertEqual(counter1.increment(), 1) + self.assertEqual(counter2.increment(), 2) + + counter2.unsafe_hard_delete() + + with self.assertRaises(Exception): + counter1.get() + with self.assertRaises(Exception): + counter2.get() + + counter3 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True).acquire() + self.assertEqual(counter3.increment(), 1) + def test_unsafe_hard_delete_no_op(self): shared1 = multi_process_shared.MultiProcessShared( Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True) From 5860f271bfb1cc7eafb5d81a812683c64d816f3f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 03:04:46 +0000 Subject: [PATCH 22/48] Fix unsafe delete --- sdks/python/apache_beam/ml/inference/base.py | 2 -- .../apache_beam/ml/inference/model_manager.py | 22 +++++++----------- .../apache_beam/utils/multi_process_shared.py | 23 +++++++------------ .../utils/multi_process_shared_test.py | 12 ++++++++++ 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 4b7761a69f2a..db8717b1bd57 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1947,8 +1947,6 @@ def _run_inference(self, batch, inference_args): except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() - if self.use_model_manager: - self._model.force_reset() if (e is pickle.PickleError and self._model_handler.share_model_across_processes()): raise TypeError( diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index f0b1d615bdb8..e8c378d791c7 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -35,25 +35,12 @@ from scipy.optimize import nnls import torch from collections import defaultdict, deque, Counter -from contextlib import contextmanager from typing import Dict, Any, Tuple, Optional, Callable # Configure Logging logger = logging.getLogger(__name__) -@contextmanager -def cuda_oom_guard(description: str): - """Safely catches OOM, clears cache, and re-raises.""" - try: - yield - except torch.cuda.OutOfMemoryError as e: - logger.error("CUDA OOM DETECTED during: %s", description) - gc.collect() - torch.cuda.empty_cache() - raise e - - class GPUMonitor: def __init__( self, @@ -364,8 +351,15 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: try: logger.info("Loading model for tag: %s", tag) isolation_baseline_snap, _, _ = self.monitor.get_stats() - with cuda_oom_guard(f"Loading {tag}"): + try: instance = loader_func() + except torch.cuda.OutOfMemoryError: + logger.error( + "CUDA OOM while loading model for tag: %s, " + "clearing all model instances and reset", + tag) + self.force_reset() + pass logger.info("Model loaded for tag: %s", tag) _, peak_during_load, _ = self.monitor.get_stats() snapshot = {tag: 1} diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 87b364a61e7c..f92094157181 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -79,6 +79,10 @@ def singletonProxy_release(self): assert self._SingletonProxy_valid self._SingletonProxy_valid = False + def unsafe_hard_delete(self): + assert self._SingletonProxy_valid + self._SingletonProxy_entry.unsafe_hard_delete() + def __getattr__(self, name): if not self._SingletonProxy_valid: raise RuntimeError('Entry was released.') @@ -105,6 +109,7 @@ def __dir__(self): dir = self._SingletonProxy_entry.obj.__dir__() dir.append('singletonProxy_call__') dir.append('singletonProxy_release') + dir.append('unsafe_hard_delete') return dir @@ -191,12 +196,8 @@ class _SingletonRegistrar(multiprocessing.managers.BaseManager): # singletonProxy_call__ calls (which is a wrapper around the underlying # object's __call__ function) class _AutoProxyWrapper: - def __init__( - self, - proxyObject: multiprocessing.managers.BaseProxy, - deleter: Optional[Callable[[], None]] = None): + def __init__(self, proxyObject: multiprocessing.managers.BaseProxy): self._proxyObject = proxyObject - self._deleter = deleter def __call__(self, *args, **kwargs): return self._proxyObject.singletonProxy_call__(*args, **kwargs) @@ -214,11 +215,7 @@ def get_auto_proxy_object(self): return self._proxyObject def unsafe_hard_delete(self): - if self._deleter: - self._deleter() - else: - raise NotImplementedError( - "This proxy was not initialized with deletion capabilities.") + return self._proxyObject.unsafe_hard_delete() class MultiProcessShared(Generic[T]): @@ -318,11 +315,7 @@ def acquire(self): # Caveat: They must always agree, as they will be ignored if the object # is already constructed. singleton = self._get_manager().acquire_singleton(self._tag) - - def deleter(): - self._get_manager().unsafe_hard_delete_singleton(self._tag) - - return _AutoProxyWrapper(singleton, deleter=deleter) + return _AutoProxyWrapper(singleton) def release(self, obj): self._manager.release_singleton(self._tag, obj.get_auto_proxy_object()) diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 68dd3d7221af..d87eeea1c01a 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -270,6 +270,18 @@ def test_release_always_proxy(self): with self.assertRaisesRegex(Exception, 'released'): counter1.get() + def test_proxy_on_proxy(self): + class SimpleClass: + def make_proxy(self): + return multi_process_shared.MultiProcessShared( + Counter, tag='proxy_on_proxy', always_proxy=True).acquire() + + shared1 = multi_process_shared.MultiProcessShared( + SimpleClass, tag='proxy_on_proxy_main', always_proxy=True) + instance = shared1.acquire() + proxy_instance = instance.make_proxy() + self.assertEqual(proxy_instance.increment(), 1) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From ed4a578f43de31907057ab992358067c4fd65340 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 07:53:04 +0000 Subject: [PATCH 23/48] Make more fixes --- sdks/python/apache_beam/ml/inference/base.py | 9 ++- .../apache_beam/ml/inference/base_test.py | 64 ++++------------ .../ml/inference/model_manager_it_test.py | 73 +++++++++++++++++++ 3 files changed, 96 insertions(+), 50 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index db8717b1bd57..c646b65cb4c6 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -328,6 +328,11 @@ def model_copies(self) -> int: of being loaded per process.""" return 1 + def model_copies_not_overriden(self) -> bool: + """Returns whether the model_copies method has been overridden by the + child class. Used to determine if the model manager should be used.""" + return type(self).model_copies == ModelHandler.model_copies + def override_metrics(self, metrics_namespace: str = '') -> bool: """Returns a boolean representing whether or not a model handler will override metrics reporting. If True, RunInference will not report any @@ -1878,7 +1883,9 @@ def load(): model_tag = side_input_model_path # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) - if self.use_model_manager: + if self.use_model_manager and \ + self._model_handler.model_copies_not_overriden( + ): model_manager = multi_process_shared.MultiProcessShared( lambda: ModelManager(**self._model_manager_args), tag='model_manager', diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 051190be380c..3e3392bbac57 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -51,16 +51,6 @@ def predict(self, example: int) -> int: return example + 1 -class FakeFailsOnceModel: - _has_failed = False - - def predict(self, example: int) -> int: - if not FakeFailsOnceModel._has_failed: - FakeFailsOnceModel._has_failed = True - raise Exception('Intentional Failure') - return example - - class FakeStatefulModel: def __init__(self, state: int): if state == 100: @@ -127,6 +117,19 @@ def batch_elements_kwargs(self): return {'min_batch_size': 1, 'max_batch_size': 1} +class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]): + def load_model(self): + return FakeModel() + + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: + for example in batch: + yield model.predict(example) + + class FakeModelHandler(base.ModelHandler[int, int, FakeModel]): def __init__( self, @@ -138,7 +141,6 @@ def __init__( incrementing=False, max_copies=1, num_bytes_per_element=None, - inference_fail_once=False, **kwargs): self._fake_clock = clock self._min_batch_size = min_batch_size @@ -149,14 +151,11 @@ def __init__( self._incrementing = incrementing self._max_copies = max_copies self._num_bytes_per_element = num_bytes_per_element - self._inference_fail_once = inference_fail_once def load_model(self): assert (not self._incrementing or self._state is None) if self._fake_clock: self._fake_clock.current_time_ns += 500_000_000 # 500ms - if self._inference_fail_once: - return FakeFailsOnceModel() if self._incrementing: return FakeIncrementingModel() if self._state is not None: @@ -1901,7 +1900,7 @@ def test_run_inference_impl_with_model_manager(self): expected = [example + 1 for example in examples] pcoll = pipeline | 'start' >> beam.Create(examples) actual = pcoll | base.RunInference( - FakeModelHandler(multi_process_shared=True), use_model_manager=True) + SimpleFakeModelHanlder(), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') def test_run_inference_impl_with_model_manager_args(self): @@ -1910,8 +1909,7 @@ def test_run_inference_impl_with_model_manager_args(self): expected = [example + 1 for example in examples] pcoll = pipeline | 'start' >> beam.Create(examples) actual = pcoll | base.RunInference( - FakeModelHandler( - multi_process_shared=True, min_batch_size=2, max_batch_size=4), + SimpleFakeModelHanlder(), use_model_manager=True, model_manager_args={ 'slack_percentage': 0.2, @@ -1922,38 +1920,6 @@ def test_run_inference_impl_with_model_manager_args(self): }) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_run_inference_impl_with_model_manager_fail_and_retry(self): - pipeline = TestPipeline() - examples = [1, 5, 3, 10] - expected = [example + 1 for example in examples] - with self.assertRaises(Exception): - actual = ( - pipeline | 'start' >> beam.Create(examples) - | base.RunInference( - FakeModelHandler( - multi_process_shared=True, inference_fail_once=True), - use_model_manager=True)) - pipeline.run() - assert_that(actual, equal_to(expected), label='assert:inferences') - - def test_run_inference_impl_with_model_manager_keyed_handler(self): - with TestPipeline() as pipeline: - examples = [1, 5, 3, 10] - keyed_examples = [(i, example) for i, example in enumerate(examples)] - expected = [(i, example + 1) for i, example in enumerate(examples)] - expected[0] = (0, 200) - pcoll = pipeline | 'start' >> beam.Create(keyed_examples) - mhs = [ - base.KeyModelMapping([0], - FakeModelHandler( - state=200, multi_process_shared=True)), - base.KeyModelMapping([1, 2, 3], - FakeModelHandler(multi_process_shared=True)) - ] - actual = pcoll | base.RunInference( - base.KeyedModelHandler(mhs), use_model_manager=True) - assert_that(actual, equal_to(expected), label='assert:inferences') - def _always_retry(e: Exception) -> bool: return True diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index e55c2ec41a16..c0b367a80eaa 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -113,3 +113,76 @@ def test_sentiment_analysis_large_roberta_gpu(self): actual_labels, equal_to(expected_labels), label='CheckPredictionsLarge') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_parallel_inference_branches(self): + """ + Tests a branching pipeline where one input source feeds two + RunInference transforms running in parallel. + + Topology: + [ Input Data ] + | + +--------+--------+ + | | + [ Translation ] [ Sentiment ] + | | + [ Check Trans ] [ Check Sent ] + """ + + translator_handler = HuggingFacePipelineModelHandler( + task="translation_en_to_es", + model="Helsinki-NLP/opus-mt-en-es", + device=0, + inference_args={"batch_size": 8 + }) # Increased batch size for throughput + sentiment_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="nlptown/bert-base-multilingual-uncased-sentiment", + device=0, + inference_args={"batch_size": 8}) + base_examples = [ + "I love this product.", # Trans: Me encanta... | Sent: 5 stars + "This is terrible.", # Trans: Esto es... | Sent: 1 star + "Hello world.", # Trans: Hola mundo. | Sent: 4/5 stars + "The service was okay.", # Trans: El servicio...| Sent: 3 stars + "I am extremely angry." # Trans: Estoy... | Sent: 1 star + ] + MULTIPLIER = 10 + examples = base_examples * MULTIPLIER + + with TestPipeline() as pipeline: + inputs = pipeline | 'CreateInputs' >> beam.Create(examples) + translations = ( + inputs + | 'RunTranslation' >> RunInference(translator_handler) + | 'ExtractSpanish' >> + beam.Map(lambda x: x.inference[0]['translation_text'])) + sentiments = ( + inputs + | 'RunSentiment' >> RunInference(sentiment_handler) + | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) + + expected_translations = [ + "Me encanta este producto.", + "Esto es terrible.", + "Hola mundo.", + "El servicio estuvo bien.", + "Estoy extremadamente enojado." + ] * MULTIPLIER + + assert_that( + translations, + equal_to(expected_translations), + label='CheckTranslations') + + expected_sentiments = [ + "5 stars", # love + "1 star", # terrible + "5 stars", # Hello world + "3 stars", # okay + "1 star" # angry + ] * MULTIPLIER + + assert_that( + sentiments, equal_to(expected_sentiments), label='CheckSentiments') From f26923c3ff2529094ae27074495b171b4f3db813 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 07:59:08 +0000 Subject: [PATCH 24/48] Fix indent --- .../ml/inference/model_manager_it_test.py | 143 +++++++++--------- 1 file changed, 71 insertions(+), 72 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index c0b367a80eaa..093c5857f42d 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -114,75 +114,74 @@ def test_sentiment_analysis_large_roberta_gpu(self): equal_to(expected_labels), label='CheckPredictionsLarge') - @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") - def test_parallel_inference_branches(self): - """ - Tests a branching pipeline where one input source feeds two - RunInference transforms running in parallel. - - Topology: - [ Input Data ] - | - +--------+--------+ - | | - [ Translation ] [ Sentiment ] - | | - [ Check Trans ] [ Check Sent ] - """ - - translator_handler = HuggingFacePipelineModelHandler( - task="translation_en_to_es", - model="Helsinki-NLP/opus-mt-en-es", - device=0, - inference_args={"batch_size": 8 - }) # Increased batch size for throughput - sentiment_handler = HuggingFacePipelineModelHandler( - task="sentiment-analysis", - model="nlptown/bert-base-multilingual-uncased-sentiment", - device=0, - inference_args={"batch_size": 8}) - base_examples = [ - "I love this product.", # Trans: Me encanta... | Sent: 5 stars - "This is terrible.", # Trans: Esto es... | Sent: 1 star - "Hello world.", # Trans: Hola mundo. | Sent: 4/5 stars - "The service was okay.", # Trans: El servicio...| Sent: 3 stars - "I am extremely angry." # Trans: Estoy... | Sent: 1 star - ] - MULTIPLIER = 10 - examples = base_examples * MULTIPLIER - - with TestPipeline() as pipeline: - inputs = pipeline | 'CreateInputs' >> beam.Create(examples) - translations = ( - inputs - | 'RunTranslation' >> RunInference(translator_handler) - | 'ExtractSpanish' >> - beam.Map(lambda x: x.inference[0]['translation_text'])) - sentiments = ( - inputs - | 'RunSentiment' >> RunInference(sentiment_handler) - | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) - - expected_translations = [ - "Me encanta este producto.", - "Esto es terrible.", - "Hola mundo.", - "El servicio estuvo bien.", - "Estoy extremadamente enojado." - ] * MULTIPLIER - - assert_that( - translations, - equal_to(expected_translations), - label='CheckTranslations') - - expected_sentiments = [ - "5 stars", # love - "1 star", # terrible - "5 stars", # Hello world - "3 stars", # okay - "1 star" # angry - ] * MULTIPLIER - - assert_that( - sentiments, equal_to(expected_sentiments), label='CheckSentiments') + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_parallel_inference_branches(self): + """ + Tests a branching pipeline where one input source feeds two + RunInference transforms running in parallel. + + Topology: + [ Input Data ] + | + +--------+--------+ + | | + [ Translation ] [ Sentiment ] + | | + [ Check Trans ] [ Check Sent ] + """ + + translator_handler = HuggingFacePipelineModelHandler( + task="translation_en_to_es", + model="Helsinki-NLP/opus-mt-en-es", + device=0, + inference_args={"batch_size": 8}) # Increased batch size for throughput + sentiment_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="nlptown/bert-base-multilingual-uncased-sentiment", + device=0, + inference_args={"batch_size": 8}) + base_examples = [ + "I love this product.", # Trans: Me encanta... | Sent: 5 stars + "This is terrible.", # Trans: Esto es... | Sent: 1 star + "Hello world.", # Trans: Hola mundo. | Sent: 4/5 stars + "The service was okay.", # Trans: El servicio...| Sent: 3 stars + "I am extremely angry." # Trans: Estoy... | Sent: 1 star + ] + MULTIPLIER = 10 + examples = base_examples * MULTIPLIER + + with TestPipeline() as pipeline: + inputs = pipeline | 'CreateInputs' >> beam.Create(examples) + translations = ( + inputs + | 'RunTranslation' >> RunInference(translator_handler) + | 'ExtractSpanish' >> + beam.Map(lambda x: x.inference[0]['translation_text'])) + sentiments = ( + inputs + | 'RunSentiment' >> RunInference(sentiment_handler) + | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) + + expected_translations = [ + "Me encanta este producto.", + "Esto es terrible.", + "Hola mundo.", + "El servicio estuvo bien.", + "Estoy extremadamente enojado." + ] * MULTIPLIER + + assert_that( + translations, + equal_to(expected_translations), + label='CheckTranslations') + + expected_sentiments = [ + "5 stars", # love + "1 star", # terrible + "5 stars", # Hello world + "3 stars", # okay + "1 star" # angry + ] * MULTIPLIER + + assert_that( + sentiments, equal_to(expected_sentiments), label='CheckSentiments') From 0c8d26de3d4574fd52d22e998c2f06c46b009c0f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 11 Dec 2025 08:06:20 +0000 Subject: [PATCH 25/48] Fix assert --- .../ml/inference/model_manager_it_test.py | 42 ++++--------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 093c5857f42d..e626c0c4b5c8 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -134,54 +134,30 @@ def test_parallel_inference_branches(self): task="translation_en_to_es", model="Helsinki-NLP/opus-mt-en-es", device=0, - inference_args={"batch_size": 8}) # Increased batch size for throughput + inference_args={"batch_size": 8}) sentiment_handler = HuggingFacePipelineModelHandler( task="sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment", device=0, inference_args={"batch_size": 8}) base_examples = [ - "I love this product.", # Trans: Me encanta... | Sent: 5 stars - "This is terrible.", # Trans: Esto es... | Sent: 1 star - "Hello world.", # Trans: Hola mundo. | Sent: 4/5 stars - "The service was okay.", # Trans: El servicio...| Sent: 3 stars - "I am extremely angry." # Trans: Estoy... | Sent: 1 star + "I love this product.", + "This is terrible.", + "Hello world.", + "The service was okay.", + "I am extremely angry." ] MULTIPLIER = 10 examples = base_examples * MULTIPLIER with TestPipeline() as pipeline: inputs = pipeline | 'CreateInputs' >> beam.Create(examples) - translations = ( + _ = ( inputs | 'RunTranslation' >> RunInference(translator_handler) | 'ExtractSpanish' >> - beam.Map(lambda x: x.inference[0]['translation_text'])) - sentiments = ( + beam.Map(lambda x: x.inference['translation_text'])) + _ = ( inputs | 'RunSentiment' >> RunInference(sentiment_handler) | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) - - expected_translations = [ - "Me encanta este producto.", - "Esto es terrible.", - "Hola mundo.", - "El servicio estuvo bien.", - "Estoy extremadamente enojado." - ] * MULTIPLIER - - assert_that( - translations, - equal_to(expected_translations), - label='CheckTranslations') - - expected_sentiments = [ - "5 stars", # love - "1 star", # terrible - "5 stars", # Hello world - "3 stars", # okay - "1 star" # angry - ] * MULTIPLIER - - assert_that( - sentiments, equal_to(expected_sentiments), label='CheckSentiments') From c8f064e51d4c592e5bd6e8db9e6b3833dee83481 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 12 Dec 2025 06:42:19 +0000 Subject: [PATCH 26/48] Supports model eviction --- .../apache_beam/ml/inference/model_manager.py | 428 +++++++++++------- .../ml/inference/model_manager_it_test.py | 2 - .../ml/inference/model_manager_test.py | 280 ++++++++---- 3 files changed, 474 insertions(+), 236 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index e8c378d791c7..410e6e7436dd 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -34,7 +34,9 @@ import numpy as np from scipy.optimize import nnls import torch -from collections import defaultdict, deque, Counter +import heapq +import itertools +from collections import defaultdict, deque, Counter, OrderedDict from typing import Dict, Any, Tuple, Optional, Callable # Configure Logging @@ -154,6 +156,10 @@ def set_initial_estimate(self, model_tag: str, cost: float): def add_observation( self, active_snapshot: Dict[str, int], peak_memory: float): + logger.info( + "Adding Observation: Snapshot=%s, PeakMemory=%.1f MB", + active_snapshot, + peak_memory) if not active_snapshot: return with self._lock: @@ -236,210 +242,324 @@ class ModelManager: def __init__( self, - monitor: Optional[GPUMonitor] = None, + monitor: Optional['GPUMonitor'] = None, slack_percentage: float = 0.15, poll_interval: float = 0.5, peak_window_seconds: float = 30.0, min_data_points: int = 5, - smoothing_factor: float = 0.2): - self.estimator = ResourceEstimator( + smoothing_factor: float = 0.2, + eviction_cooldown_seconds: float = 10.0, + min_model_copies: int = 1): + + self._estimator = ResourceEstimator( min_data_points=min_data_points, smoothing_factor=smoothing_factor) - self.monitor = monitor if monitor else GPUMonitor( + self._monitor = monitor if monitor else GPUMonitor( poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) - self.slack_percentage = slack_percentage + self._slack_percentage = slack_percentage + + self._eviction_cooldown = eviction_cooldown_seconds + self._min_model_copies = min_model_copies - self.models = defaultdict(list) - self.idle_pool = defaultdict(list) - self.active_counts = Counter() - self.total_active_jobs = 0 - self.pending_reservations = 0.0 + # Resource State + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._total_active_jobs = 0 + self._pending_reservations = 0.0 - # State Control - self.isolation_mode = False - self.pending_isolation_count = 0 - self.isolation_baseline = 0.0 + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + self._wait_queue = [] + self._ticket_counter = itertools.count() self._cv = threading.Condition() - self.monitor.start() + self._load_lock = threading.Lock() + + self._monitor.start() def all_models(self, tag) -> list[Any]: - return self.models[tag] + return self._models[tag] def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: - logger.info( - "Acquiring model for tag: %s | " - "idle_pool size: %d | " - "active_count: %d | " - "total_active_jobs: %d | " - "pending_reservations: %.1f | " - "isolation_mode: %s | " - "pending_isolation_count: %d | " - "estimator known: %s | " - "estimator cost: %.1f MB", - tag, - len(self.idle_pool[tag]), - self.active_counts[tag], - self.total_active_jobs, - self.pending_reservations, - self.isolation_mode, - self.pending_isolation_count, - not self.estimator.is_unknown(tag), - self.estimator.get_estimate(tag), - ) - should_spawn = False - est_cost = 0.0 - is_unknown = False + current_priority = 0 if self._estimator.is_unknown(tag) else 1 + ticket_num = next(self._ticket_counter) + my_id = object() with self._cv: - while True: - is_unknown = self.estimator.is_unknown(tag) + # FAST PATH + if self._pending_isolation_count == 0 and not self._isolation_mode: + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance - # Path A: Isolation for Unknown Models - if is_unknown: - self.pending_isolation_count += 1 - try: - while self.total_active_jobs > 0 or self.isolation_mode: - self._cv.wait() - if not self.estimator.is_unknown(tag): - is_unknown = False - break - - if not is_unknown: - continue + # SLOW PATH + logger.info("Acquire Queued: tag=%s, priority=%d", tag, current_priority) + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) - self.isolation_mode = True - self.total_active_jobs += 1 - self.isolation_baseline, _, _ = self.monitor.get_stats() - self.monitor.reset_peak() - should_spawn = True - break - finally: - self.pending_isolation_count -= 1 - if not should_spawn: - self._cv.notify_all() + should_spawn = False + est_cost = 0.0 + is_unknown = False - # Path B: Concurrent Execution - else: - # Writer Priority (allow unknown models to drain system) - if self.pending_isolation_count > 0 or self.isolation_mode: + try: + while True: + if not self._wait_queue or self._wait_queue[0][2] is not my_id: self._cv.wait() continue - if self.idle_pool[tag]: - instance = self.idle_pool[tag].pop() - self.active_counts[tag] += 1 - self.total_active_jobs += 1 - return instance - - # Capacity Check - curr, peak, total = self.monitor.get_stats() - est_cost = self.estimator.get_estimate(tag) - limit = total * (1 - self.slack_percentage) - base_usage = max(curr, peak) - - if (base_usage + self.pending_reservations + est_cost) <= limit: - self.pending_reservations += est_cost - self.total_active_jobs += 1 - self.active_counts[tag] += 1 + real_is_unknown = self._estimator.is_unknown(tag) + real_priority = 0 if real_is_unknown else 1 + + if current_priority != real_priority: + heapq.heappop(self._wait_queue) + current_priority = real_priority + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + self._cv.notify_all() + continue + + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + is_unknown = real_is_unknown + + # Path A: Isolation + if is_unknown: + if self._total_active_jobs > 0: + self._cv.wait() + continue + + logger.info("Unknown model %s detected. Flushing GPU.", tag) + self._delete_all_models() + + self._isolation_mode = True + self._total_active_jobs += 1 + self._isolation_baseline, _, _ = self._monitor.get_stats() + self._monitor.reset_peak() should_spawn = True break - self._cv.wait() - - # Execution Logic (Spawn) - if should_spawn: - try: - logger.info("Loading model for tag: %s", tag) - isolation_baseline_snap, _, _ = self.monitor.get_stats() - try: - instance = loader_func() - except torch.cuda.OutOfMemoryError: - logger.error( - "CUDA OOM while loading model for tag: %s, " - "clearing all model instances and reset", - tag) - self.force_reset() - pass - logger.info("Model loaded for tag: %s", tag) - _, peak_during_load, _ = self.monitor.get_stats() - snapshot = {tag: 1} - self.estimator.add_observation( - snapshot, peak_during_load - isolation_baseline_snap) - - if not is_unknown: - self.pending_reservations = max( - 0.0, self.pending_reservations - est_cost) - self.models[tag].append(instance) - return instance - - except Exception as e: - self.total_active_jobs -= 1 - if is_unknown: - self.isolation_mode = False - self.isolation_baseline = 0.0 + # Path B: Concurrent else: - self.pending_reservations = max( - 0.0, self.pending_reservations - est_cost) - self.active_counts[tag] -= 1 + if self._pending_isolation_count > 0 or self._isolation_mode: + self._cv.wait() + continue + + curr, _, total = self._monitor.get_stats() + est_cost = self._estimator.get_estimate(tag) + limit = total * (1 - self._slack_percentage) + + # Use current usage for capacity check (ignore old spikes) + if (curr + self._pending_reservations + est_cost) <= limit: + self._pending_reservations += est_cost + self._total_active_jobs += 1 + self._active_counts[tag] += 1 + should_spawn = True + break + + # Evict to make space (passing tag to check demand/existence) + if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + continue + + self._cv.wait() + + finally: + if self._wait_queue and self._wait_queue[0][2] is my_id: + heapq.heappop(self._wait_queue) self._cv.notify_all() - raise e + + if should_spawn: + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) def release_model(self, tag: str, instance: Any): with self._cv: try: - self.total_active_jobs -= 1 - if self.active_counts[tag] > 0: - self.active_counts[tag] -= 1 + self._total_active_jobs -= 1 + if self._active_counts[tag] > 0: + self._active_counts[tag] -= 1 - # Return to pool - self.idle_pool[tag].append(instance) + self._idle_lru[id(instance)] = (tag, instance, time.time()) - _, peak_during_job, _ = self.monitor.get_stats() + _, peak_during_job, _ = self._monitor.get_stats() - if self.isolation_mode and self.active_counts[tag] == 0: - cost = max(0, peak_during_job - self.isolation_baseline) - self.estimator.set_initial_estimate(tag, cost) - self.isolation_mode = False - self.isolation_baseline = 0.0 + if self._isolation_mode and self._active_counts[tag] == 0: + cost = max(0, peak_during_job - self._isolation_baseline) + self._estimator.set_initial_estimate(tag, cost) + self._isolation_mode = False + self._isolation_baseline = 0.0 else: - # Solver Snapshot - snapshot = dict(self.active_counts) - for pool_tag, models in self.idle_pool.items(): - snapshot[pool_tag] = snapshot.get(pool_tag, 0) + len(models) - + snapshot = { + t: len(instances) + for t, instances in self._models.items() if len(instances) > 0 + } if snapshot: - logger.info( - "Release Snapshot: %s, Peak: %s MB", snapshot, peak_during_job) - self.estimator.add_observation(snapshot, peak_during_job) + self._estimator.add_observation(snapshot, peak_during_job) finally: self._cv.notify_all() - def delete_all_models(self): - for _, instances in self.models.items(): + def _try_grab_from_lru(self, tag: str) -> Any: + target_key = None + target_instance = None + + for key, (t, instance, _) in reversed(self._idle_lru.items()): + if t == tag: + target_key = key + target_instance = instance + break + + if target_instance: + del self._idle_lru[target_key] + self._active_counts[tag] += 1 + self._total_active_jobs += 1 + return target_instance + return None + + def _evict_to_make_space( + self, limit: float, est_cost: float, requesting_tag: str) -> bool: + """ + Evicts models based on Demand Magnitude + Tiers. + Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction + of the lowest-demand candidate to avoid starvation. + """ + evicted_something = False + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + if projected_usage <= limit: + return False + + now = time.time() + + demand_map = Counter() + for item in self._wait_queue: + if len(item) >= 4: + demand_map[item[3]] += 1 + + my_demand = demand_map[requesting_tag] + am_i_starving = len(self._models[requesting_tag]) == 0 + + candidates = [] + for key, (tag, instance, release_time) in self._idle_lru.items(): + candidate_demand = demand_map[tag] + + if not am_i_starving: + if candidate_demand >= my_demand: + continue + + age = now - release_time + is_cold = age >= self._eviction_cooldown + + total_copies = len(self._models[tag]) + is_surplus = total_copies > self._min_model_copies + + if is_cold and is_surplus: tier = 0 + elif not is_cold and is_surplus: tier = 1 + elif is_cold and not is_surplus: tier = 2 + else: tier = 3 + + score = (candidate_demand * 10) + tier + + candidates.append((score, release_time, key, tag, instance)) + + candidates.sort(key=lambda x: (x[0], x[1])) + + for score, _, key, tag, instance in candidates: + if projected_usage <= limit: + break + + if key not in self._idle_lru: continue + + self._perform_eviction(key, tag, instance, score) + evicted_something = True + + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + return evicted_something + + def _perform_eviction(self, key, tag, instance, score): + logger.info("Evicting Model: %s (Score %d)", tag, score) + + if key in self._idle_lru: + del self._idle_lru[key] + + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + + if instance in self._models[tag]: + self._models[tag].remove(instance) + + del instance + gc.collect() + torch.cuda.empty_cache() + self._monitor.reset_peak() + + def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): + try: + with self._load_lock: + logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) + isolation_baseline_snap, _, _ = self._monitor.get_stats() + instance = loader_func() + _, peak_during_load, _ = self._monitor.get_stats() + + with self._cv: + snapshot = {tag: 1} + self._estimator.add_observation( + snapshot, peak_during_load - isolation_baseline_snap) + + if not is_unknown: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._models[tag].append(instance) + return instance + + except Exception as e: + logger.error("Load Failed: %s. Error: %s", tag, e) + with self._cv: + self._total_active_jobs -= 1 + if is_unknown: + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._active_counts[tag] -= 1 + self._cv.notify_all() + raise e + + def _delete_all_models(self): + self._idle_lru.clear() + for _, instances in self._models.items(): for instance in instances: if hasattr(instance, "unsafe_hard_delete"): instance.unsafe_hard_delete() del instance + self._models.clear() + self._active_counts.clear() gc.collect() torch.cuda.empty_cache() - def force_reset(self): - self.delete_all_models() - self.models = defaultdict(list) - self.idle_pool = defaultdict(list) - self.active_counts = Counter() - self.total_active_jobs = 0 - self.pending_reservations = 0.0 - self.isolation_mode = False - self.pending_isolation_count = 0 - self.isolation_baseline = 0.0 + def _force_reset(self): + logger.warning("Force Reset Triggered") + self._delete_all_models() + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._wait_queue = [] + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 def shutdown(self): - self.delete_all_models() + self._delete_all_models() gc.collect() torch.cuda.empty_cache() - self.monitor.stop() + self._monitor.stop() def __del__(self): self.shutdown() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index e626c0c4b5c8..4dec86d132b2 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -126,8 +126,6 @@ def test_parallel_inference_branches(self): +--------+--------+ | | [ Translation ] [ Sentiment ] - | | - [ Check Trans ] [ Check Sent ] """ translator_handler = HuggingFacePipelineModelHandler( diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index a1d9f7276523..9070a3270396 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -35,6 +35,7 @@ def get_stats(self): def reset_peak(self): with self._lock: self._peak = self._current + self.history = [self._current] def set_usage(self, current_mb): """Sets absolute usage (legacy helper).""" @@ -61,6 +62,20 @@ def free(self, amount_mb): self._peak = max(self.history) +class MockModel: + def __init__(self, name, size, monitor): + self.name = name + self.size = size + self.monitor = monitor + self.deleted = False + self.monitor.allocate(size) + + def unsafe_hard_delete(self): + if not self.deleted: + self.monitor.free(self.size) + self.deleted = True + + class TestModelManager(unittest.TestCase): def setUp(self): """Force reset the Singleton ModelManager before every test.""" @@ -78,7 +93,7 @@ def test_model_manager_capacity_check(self): """ model_name = "known_model" model_cost = 3000.0 - self.manager.estimator.set_initial_estimate(model_name, model_cost) + self.manager._estimator.set_initial_estimate(model_name, model_cost) acquired_refs = [] def loader(): @@ -113,7 +128,7 @@ def run_inference(): def test_model_manager_unknown_model_runs_isolated(self): """Test that a model with no history runs in isolation.""" model_name = "unknown_model_v1" - self.assertTrue(self.manager.estimator.is_unknown(model_name)) + self.assertTrue(self.manager._estimator.is_unknown(model_name)) def dummy_loader(): time.sleep(0.05) @@ -121,30 +136,30 @@ def dummy_loader(): instance = self.manager.acquire_model(model_name, dummy_loader) - self.assertTrue(self.manager.isolation_mode) - self.assertEqual(self.manager.total_active_jobs, 1) + self.assertTrue(self.manager._isolation_mode) + self.assertEqual(self.manager._total_active_jobs, 1) self.manager.release_model(model_name, instance) - self.assertFalse(self.manager.isolation_mode) - self.assertFalse(self.manager.estimator.is_unknown(model_name)) + self.assertFalse(self.manager._isolation_mode) + self.assertFalse(self.manager._estimator.is_unknown(model_name)) def test_model_manager_concurrent_execution(self): """Test that multiple small known models can run together.""" model_a = "small_model_a" model_b = "small_model_b" - self.manager.estimator.set_initial_estimate(model_a, 1000.0) - self.manager.estimator.set_initial_estimate(model_b, 1000.0) + self.manager._estimator.set_initial_estimate(model_a, 1000.0) + self.manager._estimator.set_initial_estimate(model_b, 1000.0) self.mock_monitor.set_usage(1000.0) inst_a = self.manager.acquire_model(model_a, lambda: "A") inst_b = self.manager.acquire_model(model_b, lambda: "B") - self.assertEqual(self.manager.total_active_jobs, 2) + self.assertEqual(self.manager._total_active_jobs, 2) self.manager.release_model(model_a, inst_a) self.manager.release_model(model_b, inst_b) - self.assertEqual(self.manager.total_active_jobs, 0) + self.assertEqual(self.manager._total_active_jobs, 0) def test_model_manager_concurrent_mixed_workload_convergence(self): """ @@ -157,30 +172,24 @@ def run_job(model_name): cost = TRUE_COSTS[model_name] def loader(): - self.mock_monitor.allocate(cost) - time.sleep(0.01) - return f"instance_{model_name}" + model = MockModel(model_name, cost, self.mock_monitor) + return model instance = self.manager.acquire_model(model_name, loader) time.sleep(random.uniform(0.01, 0.05)) self.manager.release_model(model_name, instance) # Create a workload stream - # 15 Small jobs, 15 Medium jobs, mixed order workload = ["model_small"] * 15 + ["model_medium"] * 15 random.shuffle(workload) - # We use a thread pool slightly larger than the theoretical capacity - # to force queuing and reuse logic. - # Capacity ~12000. Small=1500, Med=3000. - # Max concurrent approx: 4 Med (12000) or 8 Small (12000). with ThreadPoolExecutor(max_workers=8) as executor: futures = [executor.submit(run_job, name) for name in workload] for f in futures: f.result() - est_small = self.manager.estimator.get_estimate("model_small") - est_med = self.manager.estimator.get_estimate("model_medium") + est_small = self.manager._estimator.get_estimate("model_small") + est_med = self.manager._estimator.get_estimate("model_medium") self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) @@ -188,7 +197,7 @@ def loader(): def test_model_manager_oom_recovery(self): """Test that the manager recovers state if a loader crashes.""" model_name = "crasher_model" - self.manager.estimator.set_initial_estimate(model_name, 1000.0) + self.manager._estimator.set_initial_estimate(model_name, 1000.0) def crashing_loader(): raise RuntimeError("CUDA OOM or similar") @@ -196,8 +205,8 @@ def crashing_loader(): with self.assertRaises(RuntimeError): self.manager.acquire_model(model_name, crashing_loader) - self.assertEqual(self.manager.total_active_jobs, 0) - self.assertEqual(self.manager.pending_reservations, 0.0) + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) self.assertFalse(self.manager._cv._is_owned()) def test_model_managaer_force_reset_on_exception(self): @@ -214,11 +223,11 @@ def dummy_loader(): self.manager.release_model(model_name, instance) instance = self.manager.acquire_model(model_name, dummy_loader) except RuntimeError: - self.manager.force_reset() - self.assertTrue(len(self.manager.models[model_name]) == 0) - self.assertEqual(self.manager.total_active_jobs, 0) - self.assertEqual(self.manager.pending_reservations, 0.0) - self.assertFalse(self.manager.isolation_mode) + self.manager._force_reset() + self.assertTrue(len(self.manager._models[model_name]) == 0) + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._isolation_mode) pass instance = self.manager.acquire_model(model_name, lambda: "model_instance") @@ -226,23 +235,19 @@ def dummy_loader(): def test_single_model_convergence_with_fluctuations(self): """ - Tests that the estimator converges to the true usage with: - 1. A single model type. - 2. Initial 'Load' cost that is lower than 'Inference' cost. - 3. High variance/fluctuation during inference. - """ + Tests that the estimator converges to the true usage with fluctuations. + """ model_name = "fluctuating_model" model_cost = 3000.0 - load_cost = 2000.0 # Initial load cost underestimates true cost + load_cost = 2000.0 def loader(): self.mock_monitor.allocate(load_cost) return model_name - # Check that initial estimate is only the load cost model = self.manager.acquire_model(model_name, loader) self.manager.release_model(model_name, model) - initial_est = self.manager.estimator.get_estimate(model_name) + initial_est = self.manager._estimator.get_estimate(model_name) self.assertEqual(initial_est, load_cost) def run_inference(): @@ -260,10 +265,158 @@ def run_inference(): for f in futures: f.result() - est_cost = self.manager.estimator.get_estimate(model_name) + est_cost = self.manager._estimator.get_estimate(model_name) self.assertAlmostEqual(est_cost, model_cost, delta=100.0) +class TestModelManagerEviction(unittest.TestCase): + def setUp(self): + self.mock_monitor = MockGPUMonitor(total_memory=12000.0) + ModelManager._instance = None + self.manager = ModelManager( + monitor=self.mock_monitor, + slack_percentage=0.0, + min_data_points=1, + eviction_cooldown_seconds=10.0, + min_model_copies=1) + + def tearDown(self): + self.manager.shutdown() + + def create_loader(self, name, size): + return lambda: MockModel(name, size, self.mock_monitor) + + def test_basic_lru_eviction(self): + self.manager._estimator.set_initial_estimate("A", 4000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 5000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a) + + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_a = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0) + + key_b = list(self.manager._idle_lru.keys())[1] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + model_a_again = self.manager.acquire_model( + "A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a_again) + + self.manager.acquire_model("C", self.create_loader("C", 5000)) + + self.assertEqual(len(self.manager.all_models("B")), 0) + self.assertEqual(len(self.manager.all_models("A")), 1) + + def test_chained_eviction(self): + self.manager._estimator.set_initial_estimate("big_guy", 8000) + models = [] + for i in range(4): + name = f"small_{i}" + m = self.manager.acquire_model(name, self.create_loader(name, 3000)) + self.manager.release_model(name, m) + models.append(m) + + self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000)) + + self.assertTrue(models[0].deleted) + self.assertTrue(models[1].deleted) + self.assertTrue(models[2].deleted) + self.assertFalse(models[3].deleted) + + def test_active_models_are_protected(self): + self.manager._estimator.set_initial_estimate("A", 6000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 4000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 6000)) + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_b = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + def acquire_c(): + return self.manager.acquire_model("C", self.create_loader("C", 4000)) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(acquire_c) + model_c = future.result(timeout=2.0) + + self.assertTrue(model_b.deleted) + self.assertFalse(model_a.deleted) + + self.manager.release_model("A", model_a) + self.manager.release_model("C", model_c) + + def test_unknown_model_clears_memory(self): + self.manager._estimator.set_initial_estimate("A", 2000) + model_a = self.manager.acquire_model("A", self.create_loader("A", 2000)) + self.manager.release_model("A", model_a) + self.assertFalse(model_a.deleted) + + self.assertTrue(self.manager._estimator.is_unknown("X")) + model_x = self.manager.acquire_model("X", self.create_loader("X", 10000)) + + self.assertTrue(model_a.deleted, "Model A should be deleted for isolation") + self.assertEqual(len(self.manager.all_models("A")), 0) + self.assertTrue(self.manager._isolation_mode) + self.manager.release_model("X", model_x) + + def test_concurrent_eviction_pressure(self): + def worker(idx): + name = f"model_{idx % 5}" + try: + m = self.manager.acquire_model(name, self.create_loader(name, 4000)) + time.sleep(0.001) + self.manager.release_model(name, m) + except Exception: + pass + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(worker, i) for i in range(50)] + for f in futures: + f.result() + + curr, _, _ = self.mock_monitor.get_stats() + expected_usage = 0 + for _, instances in self.manager._models.items(): + expected_usage += len(instances) * 4000 + + self.assertAlmostEqual(curr, expected_usage) + + def test_starvation_prevention_overrides_demand(self): + self.manager._estimator.set_initial_estimate("A", 12000) + m_a = self.manager.acquire_model("A", self.create_loader("A", 12000)) + self.manager.release_model("A", m_a) + + def cycle_a(): + try: + m = self.manager.acquire_model("A", self.create_loader("A", 12000)) + time.sleep(0.3) + self.manager.release_model("A", m) + except Exception: + pass + + executor = ThreadPoolExecutor(max_workers=5) + for _ in range(5): + executor.submit(cycle_a) + + def acquire_b(): + return self.manager.acquire_model("B", self.create_loader("B", 4000)) + + b_future = executor.submit(acquire_b) + model_b = b_future.result() + + self.assertTrue(m_a.deleted) + self.manager.release_model("B", model_b) + executor.shutdown(wait=True) + + class TestGPUMonitor(unittest.TestCase): def setUp(self): self.subprocess_patcher = patch('subprocess.check_output') @@ -286,7 +439,7 @@ def test_init_hardware_missing(self): self.assertFalse(monitor._gpu_available) self.assertEqual(monitor._total_memory, 12000.0) - @patch('time.sleep') # Patch sleep to speed up the test + @patch('time.sleep') def test_polling_updates_stats(self, mock_sleep): """Test that the polling loop updates current and peak usage.""" def subprocess_side_effect(*args, **kwargs): @@ -299,7 +452,7 @@ def subprocess_side_effect(*args, **kwargs): raise Exception("Unexpected command") self.mock_subprocess.side_effect = subprocess_side_effect - self.mock_subprocess.return_value = None # Clear default + self.mock_subprocess.return_value = None monitor = GPUMonitor() monitor.start() @@ -325,7 +478,7 @@ def test_reset_peak(self): monitor.reset_peak() _, peak, _ = monitor.get_stats() - self.assertEqual(peak, 2000.0) # Peak should reset to current + self.assertEqual(peak, 2000.0) class TestResourceEstimatorSolver(unittest.TestCase): @@ -334,45 +487,23 @@ def setUp(self): @patch('apache_beam.ml.inference.model_manager.nnls') def test_solver_respects_min_data_points(self, mock_nnls): - """ - Test that the solver does not run until the total number of observations - reaches MIN_DATA_POINTS, even if we have enough unique configurations. - """ mock_nnls.return_value = ([100.0, 50.0], 0.0) self.estimator.add_observation({'model_A': 1}, 500) self.estimator.add_observation({'model_B': 1}, 500) - self.assertFalse( - mock_nnls.called, - "Should not solve: Not enough data points or unique keys") + self.assertFalse(mock_nnls.called) - # Now we have 3 unique keys. For 2 models, 3 >= 2+1. - # The 'Unique Models' constraint is now SATISFIED. - # However, total observations is 3. MIN_DATA_POINTS is 5. self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000) - - self.assertFalse( - mock_nnls.called, - "Should not solve yet: Total obs (3) < MIN_DATA_POINTS") + self.assertFalse(mock_nnls.called) self.estimator.add_observation({'model_A': 1}, 500) self.assertFalse(mock_nnls.called) - # Now Total Obs = 5. MIN_DATA_POINTS satisfied. - # Unique Keys = 3. Variety satisfied. self.estimator.add_observation({'model_B': 1}, 500) - - self.assertTrue( - mock_nnls.called, - "Solver SHOULD run now that min data points are reached") + self.assertTrue(mock_nnls.called) @patch('apache_beam.ml.inference.model_manager.nnls') def test_solver_respects_unique_model_constraint(self, mock_nnls): - """ - Test that the solver does not run if we have a lot of data points, - but they don't represent enough unique configurations to solve - the linear system safely. - """ mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0) for _ in range(5): @@ -381,23 +512,12 @@ def test_solver_respects_unique_model_constraint(self, mock_nnls): for _ in range(5): self.estimator.add_observation({'model_C': 1}, 400) - # Current State: - # Total Observations: 10 (>> MIN_DATA_POINTS) - # Unique Keys: 2 ({A,B} and {C}) - # Required Keys: 4 (A, B, C + Bias) - - self.assertFalse( - mock_nnls.called, - "Should not solve: 2 unique keys provided, " - "but need 4 to solve for 3 models + bias") + self.assertFalse(mock_nnls.called) - # Now let's satisfy the constraint by adding distinct configurations - self.estimator.add_observation({'model_A': 1}, 300) # Key 3 - self.estimator.add_observation({'model_B': 1}, 300) # Key 4 + self.estimator.add_observation({'model_A': 1}, 300) + self.estimator.add_observation({'model_B': 1}, 300) - self.assertTrue( - mock_nnls.called, - "Solver SHOULD run now that we have enough unique configurations") + self.assertTrue(mock_nnls.called) if __name__ == "__main__": From 3b0f4dfd9984765c886327098d201ce94fb68c43 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 12 Dec 2025 08:16:34 +0000 Subject: [PATCH 27/48] Fix override check --- sdks/python/apache_beam/ml/inference/base.py | 8 +++++--- sdks/python/apache_beam/ml/inference/model_manager.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index c646b65cb4c6..99f37d0f90b7 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -331,7 +331,9 @@ def model_copies(self) -> int: def model_copies_not_overriden(self) -> bool: """Returns whether the model_copies method has been overridden by the child class. Used to determine if the model manager should be used.""" - return type(self).model_copies == ModelHandler.model_copies + return type( + self + ).model_copies.__qualname__ == ModelHandler.model_copies.__qualname__ def override_metrics(self, metrics_namespace: str = '') -> bool: """Returns a boolean representing whether or not a model handler will @@ -1884,8 +1886,8 @@ def load(): # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) if self.use_model_manager and \ - self._model_handler.model_copies_not_overriden( - ): + self._model_handler.model_copies_not_overriden(): + logging.info("Using Model Manager to manage models automatically.") model_manager = multi_process_shared.MultiProcessShared( lambda: ModelManager(**self._model_manager_args), tag='model_manager', diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 410e6e7436dd..fea578aa5054 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -39,7 +39,6 @@ from collections import defaultdict, deque, Counter, OrderedDict from typing import Dict, Any, Tuple, Optional, Callable -# Configure Logging logger = logging.getLogger(__name__) From bafe6e481c21f9f46fd3578d6491fb02f1b2b1fd Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 12 Dec 2025 20:57:06 +0000 Subject: [PATCH 28/48] Fix parallel process issue and add capabiility to spawn process with MPS --- sdks/python/apache_beam/ml/inference/base.py | 4 +- .../apache_beam/utils/multi_process_shared.py | 178 ++++++++++++++++-- .../utils/multi_process_shared_test.py | 92 ++++++++- 3 files changed, 248 insertions(+), 26 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 99f37d0f90b7..bec360813d56 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -1766,8 +1766,10 @@ def __init__(self, loader_func, model_tag): def __call__(self): unique_tag = self.model_tag + '_' + uuid.uuid4().hex + # Ensure that each model loaded in a different process for parallelism return multi_process_shared.MultiProcessShared( - self.loader_func, tag=unique_tag, always_proxy=True).acquire() + self.loader_func, tag=unique_tag, always_proxy=True, + spawn_process=True).acquire() class _SharedModelWrapper(): diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index f92094157181..7aa8701d2b11 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -25,6 +25,10 @@ import logging import multiprocessing.managers import os +import time +import traceback +import atexit +import sys import tempfile import threading from typing import Any @@ -218,6 +222,72 @@ def unsafe_hard_delete(self): return self._proxyObject.unsafe_hard_delete() +def _run_server_process(address_file, tag, constructor, authkey): + """ + Runs in a separate process. + Includes a 'Suicide Pact' monitor: If parent dies, I die. + """ + parent_pid = os.getppid() + + def cleanup_files(): + logging.info("Server process exiting. Deleting files for %s", tag) + try: + if os.path.exists(address_file): + os.remove(address_file) + if os.path.exists(address_file + ".error"): + os.remove(address_file + ".error") + except Exception: + pass + + def _monitor_parent(): + """Checks if parent is alive every second.""" + while True: + try: + os.kill(parent_pid, 0) + except OSError: + logging.warning( + "Process %s detected Parent %s died. Self-destructing.", + os.getpid(), + parent_pid) + cleanup_files() + os._exit(0) + time.sleep(0.5) + + atexit.register(cleanup_files) + + try: + t = threading.Thread(target=_monitor_parent, daemon=True) + t.start() + + logging.getLogger().setLevel(logging.INFO) + multiprocessing.current_process().authkey = authkey + + serving_manager = _SingletonRegistrar( + address=('localhost', 0), authkey=authkey) + _process_level_singleton_manager.register_singleton( + constructor, tag, initialize_eagerly=True) + + server = serving_manager.get_server() + logging.info( + 'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address) + + with open(address_file + '.tmp', 'w') as fout: + fout.write('%s:%d' % server.address) + os.rename(address_file + '.tmp', address_file) + + server.serve_forever() + + except Exception: + tb = traceback.format_exc() + try: + with open(address_file + ".error.tmp", 'w') as fout: + fout.write(tb) + os.rename(address_file + ".error.tmp", address_file + ".error") + except Exception: + print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr) + os._exit(1) + + class MultiProcessShared(Generic[T]): """MultiProcessShared is used to share a single object across processes. @@ -266,7 +336,8 @@ def __init__( tag: Any, *, path: str = tempfile.gettempdir(), - always_proxy: Optional[bool] = None): + always_proxy: Optional[bool] = None, + spawn_process: bool = False): self._constructor = constructor self._tag = tag self._path = path @@ -276,6 +347,7 @@ def __init__( self._rpc_address = None self._cross_process_lock = fasteners.InterProcessLock( os.path.join(self._path, self._tag) + '.lock') + self._spawn_process = spawn_process def _get_manager(self): if self._manager is None: @@ -332,22 +404,88 @@ def unsafe_hard_delete(self): self._get_manager().unsafe_hard_delete_singleton(self._tag) def _create_server(self, address_file): - # We need to be able to authenticate with both the manager and the process. - self._serving_manager = _SingletonRegistrar( - address=('localhost', 0), authkey=AUTH_KEY) - multiprocessing.current_process().authkey = AUTH_KEY - # Initialize eagerly to avoid acting as the server if there are issues. - # Note, however, that _create_server itself is called lazily. - _process_level_singleton_manager.register_singleton( - self._constructor, self._tag, initialize_eagerly=True) - self._server = self._serving_manager.get_server() - logging.info( - 'Starting proxy server at %s for shared %s', - self._server.address, - self._tag) - with open(address_file + '.tmp', 'w') as fout: - fout.write('%s:%d' % self._server.address) - os.rename(address_file + '.tmp', address_file) - t = threading.Thread(target=self._server.serve_forever, daemon=True) - t.start() - logging.info('Done starting server') + if self._spawn_process: + error_file = address_file + ".error" + + if os.path.exists(error_file): + try: + os.remove(error_file) + except OSError: + pass + + ctx = multiprocessing.get_context('spawn') + p = ctx.Process( + target=_run_server_process, + args=(address_file, self._tag, self._constructor, AUTH_KEY), + daemon=False # Must be False for nested proxies + ) + p.start() + logging.info("Parent: Waiting for %s to write address file...", self._tag) + + def cleanup_process(): + if p.is_alive(): + logging.info( + "Parent: Terminating server process %s for %s", p.pid, self._tag) + p.terminate() + p.join() + try: + if os.path.exists(address_file): + os.remove(address_file) + if os.path.exists(error_file): + os.remove(error_file) + except Exception: + pass + + atexit.register(cleanup_process) + + start_time = time.time() + while True: + if os.path.exists(address_file): + break + + if os.path.exists(error_file): + with open(error_file, 'r') as f: + error_msg = f.read() + try: + os.remove(error_file) + except OSError: + pass + + if p.is_alive(): p.terminate() + raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}") + + if not p.is_alive(): + exit_code = p.exitcode + raise RuntimeError( + "Shared Server Process died unexpectedly" + f" with exit code {exit_code}") + + if time.time() - start_time > 30: + if p.is_alive(): p.terminate() + raise TimeoutError( + f"Timed out waiting for server process {self._tag} to start.") + + time.sleep(0.05) + + logging.info('External process successfully started for %s', self._tag) + else: + # We need to be able to authenticate with both the manager + # and the process. + self._serving_manager = _SingletonRegistrar( + address=('localhost', 0), authkey=AUTH_KEY) + multiprocessing.current_process().authkey = AUTH_KEY + # Initialize eagerly to avoid acting as the server if there are issues. + # Note, however, that _create_server itself is called lazily. + _process_level_singleton_manager.register_singleton( + self._constructor, self._tag, initialize_eagerly=True) + self._server = self._serving_manager.get_server() + logging.info( + 'Starting proxy server at %s for shared %s', + self._server.address, + self._tag) + with open(address_file + '.tmp', 'w') as fout: + fout.write('%s:%d' % self._server.address) + os.rename(address_file + '.tmp', address_file) + t = threading.Thread(target=self._server.serve_forever, daemon=True) + t.start() + logging.info('Done starting server') diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index d87eeea1c01a..26f08b13bdef 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -18,6 +18,9 @@ import logging import threading +import tempfile +import os +import multiprocessing import unittest from typing import Any @@ -82,6 +85,12 @@ def __getattribute__(self, __name: str) -> Any: return object.__getattribute__(self, __name) +class SimpleClass: + def make_proxy(self): + return multi_process_shared.MultiProcessShared( + Counter, tag='proxy_on_proxy', always_proxy=True).acquire() + + class MultiProcessSharedTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -271,11 +280,6 @@ def test_release_always_proxy(self): counter1.get() def test_proxy_on_proxy(self): - class SimpleClass: - def make_proxy(self): - return multi_process_shared.MultiProcessShared( - Counter, tag='proxy_on_proxy', always_proxy=True).acquire() - shared1 = multi_process_shared.MultiProcessShared( SimpleClass, tag='proxy_on_proxy_main', always_proxy=True) instance = shared1.acquire() @@ -283,6 +287,84 @@ def make_proxy(self): self.assertEqual(proxy_instance.increment(), 1) +class MultiProcessSharedSpawnProcessTest(unittest.TestCase): + def setUp(self): + tempdir = tempfile.gettempdir() + for tag in ['basic', + 'proxy_on_proxy', + 'proxy_on_proxy_main', + 'to_delete', + 'mix1', + 'mix2']: + for ext in ['', '.address', '.address.error']: + try: + os.remove(os.path.join(tempdir, tag + ext)) + except OSError: + pass + + def tearDown(self): + for p in multiprocessing.active_children(): + p.terminate() + p.join() + + def test_call(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + self.assertEqual(shared.get(), 0) + self.assertEqual(shared.increment(), 1) + self.assertEqual(shared.increment(10), 11) + self.assertEqual(shared.increment(value=10), 21) + self.assertEqual(shared.get(), 21) + + def test_proxy_on_proxy(self): + shared1 = multi_process_shared.MultiProcessShared( + SimpleClass, + tag='proxy_on_proxy_main', + always_proxy=True, + spawn_process=True) + instance = shared1.acquire() + proxy_instance = instance.make_proxy() + self.assertEqual(proxy_instance.increment(), 1) + + def test_unsafe_hard_delete_autoproxywrapper(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, spawn_process=True) + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, spawn_process=True) + counter3 = multi_process_shared.MultiProcessShared( + Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + + counter1 = shared1.acquire() + counter2 = shared2.acquire() + self.assertEqual(counter1.increment(), 1) + self.assertEqual(counter2.increment(), 2) + + counter2.unsafe_hard_delete() + + with self.assertRaises(Exception): + counter1.get() + with self.assertRaises(Exception): + counter2.get() + + counter4 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, + spawn_process=True).acquire() + + self.assertEqual(counter3.increment(), 1) + self.assertEqual(counter4.increment(), 1) + + def test_mix_usage(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire() + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire() + + self.assertEqual(shared1.get(), 0) + self.assertEqual(shared1.increment(), 1) + self.assertEqual(shared2.get(), 0) + self.assertEqual(shared2.increment(), 1) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() From 716d27339c4ea814ed20a3e64d03980962f11fbc Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 12 Dec 2025 22:08:25 +0000 Subject: [PATCH 29/48] Remove override check --- sdks/python/apache_beam/ml/inference/base.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index bec360813d56..ebb3c3709ee9 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -328,13 +328,6 @@ def model_copies(self) -> int: of being loaded per process.""" return 1 - def model_copies_not_overriden(self) -> bool: - """Returns whether the model_copies method has been overridden by the - child class. Used to determine if the model manager should be used.""" - return type( - self - ).model_copies.__qualname__ == ModelHandler.model_copies.__qualname__ - def override_metrics(self, metrics_namespace: str = '') -> bool: """Returns a boolean representing whether or not a model handler will override metrics reporting. If True, RunInference will not report any @@ -1887,8 +1880,7 @@ def load(): model_tag = side_input_model_path # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) - if self.use_model_manager and \ - self._model_handler.model_copies_not_overriden(): + if self.use_model_manager: logging.info("Using Model Manager to manage models automatically.") model_manager = multi_process_shared.MultiProcessShared( lambda: ModelManager(**self._model_manager_args), From 6231af0d44901018114d17147fcaacfc6469648d Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 12 Dec 2025 23:32:11 +0000 Subject: [PATCH 30/48] Update the default slack percentage --- sdks/python/apache_beam/ml/inference/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index fea578aa5054..81acb4d3e482 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -242,7 +242,7 @@ class ModelManager: def __init__( self, monitor: Optional['GPUMonitor'] = None, - slack_percentage: float = 0.15, + slack_percentage: float = 0.10, poll_interval: float = 0.5, peak_window_seconds: float = 30.0, min_data_points: int = 5, From 6b428db579783b159f35029d32f223e4d66adbe2 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 00:29:10 +0000 Subject: [PATCH 31/48] Skip tests if dependencies is missing --- .../python/apache_beam/ml/inference/model_manager_it_test.py | 4 +++- sdks/python/apache_beam/ml/inference/model_manager_test.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 4dec86d132b2..9d9c31f77490 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -1,8 +1,10 @@ import unittest -import torch import apache_beam as beam from apache_beam.ml.inference.base import RunInference + +# pylint: disable=ungrouped-imports try: + import torch from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler except ImportError as e: raise unittest.SkipTest( diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 9070a3270396..9334df95222a 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -5,7 +5,10 @@ from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch -from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator +try: + from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator +except ImportError as e: + raise unittest.SkipTest("Model Manager dependencies are not installed") class MockGPUMonitor: From b245c1301a92087f33c825d771d7abb289754b97 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 00:29:41 +0000 Subject: [PATCH 32/48] Reorder import --- sdks/python/apache_beam/ml/inference/model_manager_it_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 9d9c31f77490..6148e0330204 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -1,6 +1,8 @@ import unittest import apache_beam as beam from apache_beam.ml.inference.base import RunInference +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to # pylint: disable=ungrouped-imports try: @@ -9,8 +11,6 @@ except ImportError as e: raise unittest.SkipTest( "HuggingFace model handler dependencies are not installed") -from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that, equal_to class HuggingFaceGpuTest(unittest.TestCase): From ffd53d21d8025fc8ec473b01b24b09f706f171d7 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 00:38:28 +0000 Subject: [PATCH 33/48] Add license --- .../apache_beam/ml/inference/model_manager.py | 2 -- .../ml/inference/model_manager_it_test.py | 17 +++++++++++++++++ .../ml/inference/model_manager_test.py | 17 +++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 81acb4d3e482..775e73a32215 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# TODO: https://github.com/apache/beam/issues/21822 -# mypy: ignore-errors """Module for managing ML models in Apache Beam pipelines. diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 6148e0330204..32d007e647e6 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import unittest import apache_beam as beam from apache_beam.ml.inference.base import RunInference diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 9334df95222a..4a88b31fb10b 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import unittest import time import threading From 3d8a2b5b08cc0eb1683bfbce5e73a183c16a807f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 01:06:15 +0000 Subject: [PATCH 34/48] Skip model manager import if error --- sdks/python/apache_beam/ml/inference/base.py | 3 ++- sdks/python/apache_beam/ml/inference/base_test.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index ebb3c3709ee9..6fc7b4ed220e 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -59,13 +59,14 @@ from apache_beam.utils import multi_process_shared from apache_beam.utils import retry from apache_beam.utils import shared -from apache_beam.ml.inference.model_manager import ModelManager try: # pylint: disable=wrong-import-order, wrong-import-position import resource + from apache_beam.ml.inference.model_manager import ModelManager except ImportError: resource = None # type: ignore[assignment] + ModelManager = None # type: ignore[assignment] _NANOSECOND_TO_MILLISECOND = 1_000_000 _NANOSECOND_TO_MICROSECOND = 1_000 diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 3e3392bbac57..b5d739a6b7a9 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -317,6 +317,15 @@ def validate_inference_args(self, inference_args): pass +def try_import_model_manager(): + try: + # pylint: disable=unused-import + from apache_beam.ml.inference.model_manager import ModelManager + return True + except ImportError: + return False + + class RunInferenceBaseTest(unittest.TestCase): def test_run_inference_impl_simple_examples(self): with TestPipeline() as pipeline: @@ -1894,6 +1903,7 @@ def test_model_status_provides_valid_garbage_collection(self): self.assertEqual(0, len(tags)) + @skipIf(try_import_model_manager() is False, 'Model Manager not available') def test_run_inference_impl_with_model_manager(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] @@ -1903,6 +1913,7 @@ def test_run_inference_impl_with_model_manager(self): SimpleFakeModelHanlder(), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') + @skipIf(try_import_model_manager() is False, 'Model Manager not available') def test_run_inference_impl_with_model_manager_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] From 6a0bd89b8c6ecb8a006919dcfffda69ce1fb47e3 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 01:34:31 +0000 Subject: [PATCH 35/48] Fix skipif --- sdks/python/apache_beam/ml/inference/base_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index b5d739a6b7a9..86c0ae95318d 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1903,7 +1903,8 @@ def test_model_status_provides_valid_garbage_collection(self): self.assertEqual(0, len(tags)) - @skipIf(try_import_model_manager() is False, 'Model Manager not available') + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') def test_run_inference_impl_with_model_manager(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] @@ -1913,7 +1914,8 @@ def test_run_inference_impl_with_model_manager(self): SimpleFakeModelHanlder(), use_model_manager=True) assert_that(actual, equal_to(expected), label='assert:inferences') - @skipIf(try_import_model_manager() is False, 'Model Manager not available') + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') def test_run_inference_impl_with_model_manager_args(self): with TestPipeline() as pipeline: examples = [1, 5, 3, 10] From e375b458aa8511cb6ecd1fa410f88461c3e7e57e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 01:49:07 +0000 Subject: [PATCH 36/48] Enable model manager on the it test --- .../ml/inference/model_manager_it_test.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py index 32d007e647e6..609de0606a14 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -62,7 +62,8 @@ def test_sentiment_analysis_on_gpu_large_input(self): pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) - predictions = pcoll | 'RunInference' >> RunInference(model_handler) + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) actual_labels = predictions | beam.Map(lambda x: x.inference['label']) @@ -112,7 +113,8 @@ def test_sentiment_analysis_large_roberta_gpu(self): ] * DUPLICATE_FACTOR pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) - predictions = pcoll | 'RunInference' >> RunInference(model_handler) + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) actual_labels = predictions | beam.Map(lambda x: x.inference['label']) expected_labels = [ @@ -171,10 +173,12 @@ def test_parallel_inference_branches(self): inputs = pipeline | 'CreateInputs' >> beam.Create(examples) _ = ( inputs - | 'RunTranslation' >> RunInference(translator_handler) + | 'RunTranslation' >> RunInference( + translator_handler, use_model_manager=True) | 'ExtractSpanish' >> beam.Map(lambda x: x.inference['translation_text'])) _ = ( inputs - | 'RunSentiment' >> RunInference(sentiment_handler) + | 'RunSentiment' >> RunInference( + sentiment_handler, use_model_manager=True) | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) From 540cc8c646dbc686f0d1a26e2c6c98819bf06de3 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 04:27:02 +0000 Subject: [PATCH 37/48] Add timeout and force poll gpu usage to prevent race condition --- .../apache_beam/ml/inference/model_manager.py | 20 ++++++++++++++++++- .../ml/inference/model_manager_test.py | 4 ++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 775e73a32215..5f44eead5e06 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -103,6 +103,21 @@ def get_stats(self) -> Tuple[float, float, float]: with self._lock: return self._current_usage, self._peak_usage, self._total_memory + def refresh(self): + """Forces an immediate poll of the GPU.""" + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + # Recalculate peak immediately + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + def _get_nvidia_smi_used(self) -> float: try: cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" @@ -360,7 +375,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): continue - self._cv.wait() + self._cv.wait(timeout=10.0) finally: if self._wait_queue and self._wait_queue[0][2] is my_id: @@ -492,6 +507,7 @@ def _perform_eviction(self, key, tag, instance, score): del instance gc.collect() torch.cuda.empty_cache() + self._monitor.refresh() self._monitor.reset_peak() def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): @@ -538,6 +554,8 @@ def _delete_all_models(self): self._active_counts.clear() gc.collect() torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() def _force_reset(self): logger.warning("Force Reset Triggered") diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 4a88b31fb10b..7412ea6a6c64 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -81,6 +81,10 @@ def free(self, amount_mb): self.history.pop(0) self._peak = max(self.history) + def refresh(self): + """Simulates a refresh of the monitor stats (no-op for mock).""" + pass + class MockModel: def __init__(self, name, size, monitor): From d3924380869173b73a7be7ea49ea4c93bac8848e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 05:30:36 +0000 Subject: [PATCH 38/48] Cleanup queue to avoid leaving zombie ticket --- sdks/python/apache_beam/ml/inference/model_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 5f44eead5e06..34b29ace019f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -380,7 +380,12 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: finally: if self._wait_queue and self._wait_queue[0][2] is my_id: heapq.heappop(self._wait_queue) - self._cv.notify_all() + else: + for i, item in enumerate(self._wait_queue): + if item[2] is my_id: + self._wait_queue.pop(i) + heapq.heapify(self._wait_queue) + self._cv.notify_all() if should_spawn: return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) From ed0c2d82e69d3a81971b20f4224aa8e3d627979f Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sat, 13 Dec 2025 22:54:16 +0000 Subject: [PATCH 39/48] Add logging for debugging --- .../apache_beam/ml/inference/model_manager.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 34b29ace019f..d0bbe63b00f2 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -306,7 +306,11 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: return cached_instance # SLOW PATH - logger.info("Acquire Queued: tag=%s, priority=%d", tag, current_priority) + logger.info( + "Acquire Queued: tag=%s, priority=%d total models count=%s", + tag, + current_priority, + len(self._models[tag])) heapq.heappush( self._wait_queue, (current_priority, ticket_num, my_id, tag)) @@ -317,6 +321,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: try: while True: if not self._wait_queue or self._wait_queue[0][2] is not my_id: + logger.info( + "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) self._cv.wait() continue @@ -340,6 +346,10 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Path A: Isolation if is_unknown: if self._total_active_jobs > 0: + logger.info( + "Waiting to enter isolation: tag=%s ticket num=%s", + tag, + ticket_num) self._cv.wait() continue @@ -356,6 +366,10 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Path B: Concurrent else: if self._pending_isolation_count > 0 or self._isolation_mode: + logger.info( + "Waiting due to isolation in progress: tag=%s ticket num%s", + tag, + ticket_num) self._cv.wait() continue @@ -375,6 +389,10 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): continue + logger.info( + "Waiting for resources to free up: tag=%s ticket num%s", + tag, + ticket_num) self._cv.wait(timeout=10.0) finally: From c380699ed7f3376474b514d5169718a52410ba13 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sun, 14 Dec 2025 01:21:52 +0000 Subject: [PATCH 40/48] Add more logging --- .../apache_beam/ml/inference/model_manager.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index d0bbe63b00f2..81afdb22a76b 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -307,10 +307,12 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # SLOW PATH logger.info( - "Acquire Queued: tag=%s, priority=%d total models count=%s", + "Acquire Queued: tag=%s, priority=%d " + "total models count=%s ticket num=%s", tag, current_priority, - len(self._models[tag])) + len(self._models[tag]), + ticket_num) heapq.heappush( self._wait_queue, (current_priority, ticket_num, my_id, tag)) @@ -389,10 +391,29 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): continue + idle_count = 0 + other_idle_count = 0 + for item in self._idle_lru.items(): + if item[1][0] == tag: + idle_count += 1 + else: + other_idle_count += 1 + total_model_count = 0 + for _, instances in self._models.items(): + total_model_count += len(instances) + curr, _, _ = self._monitor.get_stats() logger.info( - "Waiting for resources to free up: tag=%s ticket num%s", + "Waiting for resources to free up: " + "tag=%s ticket num%s model count=%s " + "idle count=%s resource usage=%.1f MB " + "total models count=%s other idle=%s", tag, - ticket_num) + ticket_num, + len(self._models[tag]), + idle_count, + curr, + total_model_count, + other_idle_count) self._cv.wait(timeout=10.0) finally: @@ -450,6 +471,8 @@ def _try_grab_from_lru(self, tag: str) -> Any: self._active_counts[tag] += 1 self._total_active_jobs += 1 return target_instance + + logger.info("No idle model found for tag: %s", tag) return None def _evict_to_make_space( From 2e1b7d90cf54d388d0da7fa27c8fb0cb5827a65b Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Sun, 14 Dec 2025 05:20:33 +0000 Subject: [PATCH 41/48] Fix delete model --- .../apache_beam/ml/inference/model_manager.py | 10 ++++-- .../apache_beam/utils/multi_process_shared.py | 33 +++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 81afdb22a76b..30dbe08304f3 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -540,21 +540,25 @@ def _evict_to_make_space( def _perform_eviction(self, key, tag, instance, score): logger.info("Evicting Model: %s (Score %d)", tag, score) + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage Before Eviction: %.1f MB", curr) if key in self._idle_lru: del self._idle_lru[key] - if hasattr(instance, "unsafe_hard_delete"): - instance.unsafe_hard_delete() - if instance in self._models[tag]: self._models[tag].remove(instance) + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + del instance gc.collect() torch.cuda.empty_cache() self._monitor.refresh() self._monitor.reset_peak() + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage After Eviction: %.1f MB", curr) def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): try: diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 7aa8701d2b11..9c158a86b345 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -155,6 +155,12 @@ def unsafe_hard_delete(self): class _SingletonManager: entries: Dict[Any, Any] = {} + def __init__(self): + self._hard_delete_callback = None + + def set_hard_delete_callback(self, callback): + self._hard_delete_callback = callback + def register_singleton(self, constructor, tag, initialize_eagerly=True): assert tag not in self.entries, tag self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly) @@ -219,7 +225,13 @@ def get_auto_proxy_object(self): return self._proxyObject def unsafe_hard_delete(self): - return self._proxyObject.unsafe_hard_delete() + try: + self._proxyObject.unsafe_hard_delete() + except (EOFError, ConnectionResetError, BrokenPipeError): + pass + except Exception as e: + logging.warning( + "Exception %s when trying to hard delete shared object proxy", e) def _run_server_process(address_file, tag, constructor, authkey): @@ -239,6 +251,13 @@ def cleanup_files(): except Exception: pass + def handle_unsafe_hard_delete(): + def do_exit(): + cleanup_files() + os._exit(0) + + threading.Thread(target=do_exit, daemon=True).start() + def _monitor_parent(): """Checks if parent is alive every second.""" while True: @@ -264,6 +283,8 @@ def _monitor_parent(): serving_manager = _SingletonRegistrar( address=('localhost', 0), authkey=authkey) + _process_level_singleton_manager.set_hard_delete_callback( + handle_unsafe_hard_delete) _process_level_singleton_manager.register_singleton( constructor, tag, initialize_eagerly=True) @@ -401,7 +422,15 @@ def unsafe_hard_delete(self): to this object exist, or (b) you are ok with all existing references to this object throwing strange errors when derefrenced. """ - self._get_manager().unsafe_hard_delete_singleton(self._tag) + try: + self._get_manager().unsafe_hard_delete_singleton(self._tag) + except (EOFError, ConnectionResetError, BrokenPipeError): + pass + except Exception as e: + logging.warning( + "Exception %s when trying to hard delete shared object %s", + e, + self._tag) def _create_server(self, address_file): if self._spawn_process: From 89047b73553505a053ce6cd5c3ede9749ef7841e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 00:51:42 +0000 Subject: [PATCH 42/48] Make sure process exit --- .../apache_beam/utils/multi_process_shared.py | 34 ++++++++++++++----- .../utils/multi_process_shared_test.py | 18 ++++++---- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 9c158a86b345..167e62e9492b 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -117,6 +117,25 @@ def __dir__(self): return dir +def _run_with_oom_protection(func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + # Check string to avoid hard import dependency + if 'CUDA out of memory' in str(e): + logging.warning("Caught CUDA OOM during operation. Cleaning memory.") + try: + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + except ImportError: + pass + except Exception as cleanup_error: + logging.error("Failed to clean up CUDA memory: %s", cleanup_error) + raise e + + class _SingletonEntry: """Represents a single, refcounted entry in this process.""" def __init__(self, constructor, initialize_eagerly=True): @@ -124,7 +143,7 @@ def __init__(self, constructor, initialize_eagerly=True): self.refcount = 0 self.lock = threading.Lock() if initialize_eagerly: - self.obj = constructor() + self.obj = _run_with_oom_protection(constructor) self.initialied = True else: self.initialied = False @@ -132,7 +151,7 @@ def __init__(self, constructor, initialize_eagerly=True): def acquire(self): with self.lock: if not self.initialied: - self.obj = self.constructor() + self.obj = _run_with_oom_protection(self.constructor) self.initialied = True self.refcount += 1 return _SingletonProxy(self) @@ -175,7 +194,8 @@ def release_singleton(self, tag, obj): return self.entries[tag].release(obj) def unsafe_hard_delete_singleton(self, tag): - return self.entries[tag].unsafe_hard_delete() + self.entries[tag].unsafe_hard_delete() + self._hard_delete_callback() _process_level_singleton_manager = _SingletonManager() @@ -240,6 +260,7 @@ def _run_server_process(address_file, tag, constructor, authkey): Includes a 'Suicide Pact' monitor: If parent dies, I die. """ parent_pid = os.getppid() + os.setsid() def cleanup_files(): logging.info("Server process exiting. Deleting files for %s", tag) @@ -252,11 +273,8 @@ def cleanup_files(): pass def handle_unsafe_hard_delete(): - def do_exit(): - cleanup_files() - os._exit(0) - - threading.Thread(target=do_exit, daemon=True).start() + cleanup_files() + os._exit(0) def _monitor_parent(): """Checks if parent is alive every second.""" diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 26f08b13bdef..4ccf19321243 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -86,9 +86,11 @@ def __getattribute__(self, __name: str) -> Any: class SimpleClass: - def make_proxy(self): + def make_proxy( + self, tag: str = 'proxy_on_proxy', spawn_process: bool = False): return multi_process_shared.MultiProcessShared( - Counter, tag='proxy_on_proxy', always_proxy=True).acquire() + Counter, tag=tag, always_proxy=True, + spawn_process=spawn_process).acquire() class MultiProcessSharedTest(unittest.TestCase): @@ -293,6 +295,7 @@ def setUp(self): for tag in ['basic', 'proxy_on_proxy', 'proxy_on_proxy_main', + 'main', 'to_delete', 'mix1', 'mix2']: @@ -318,13 +321,14 @@ def test_call(self): def test_proxy_on_proxy(self): shared1 = multi_process_shared.MultiProcessShared( - SimpleClass, - tag='proxy_on_proxy_main', - always_proxy=True, - spawn_process=True) + SimpleClass, tag='main', always_proxy=True) instance = shared1.acquire() - proxy_instance = instance.make_proxy() + proxy_instance = instance.make_proxy(spawn_process=True) self.assertEqual(proxy_instance.increment(), 1) + proxy_instance.unsafe_hard_delete() + + proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True) + self.assertEqual(proxy_instance2.increment(), 1) def test_unsafe_hard_delete_autoproxywrapper(self): shared1 = multi_process_shared.MultiProcessShared( From 79eafe8d780dca1f2d9b5aa8add49d55be4bcd42 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 01:08:05 +0000 Subject: [PATCH 43/48] Make sure process exit without manager --- .../apache_beam/utils/multi_process_shared.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 167e62e9492b..cd79701f2478 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -138,8 +138,10 @@ def _run_with_oom_protection(func, *args, **kwargs): class _SingletonEntry: """Represents a single, refcounted entry in this process.""" - def __init__(self, constructor, initialize_eagerly=True): + def __init__( + self, constructor, initialize_eagerly=True, hard_delete_callback=None): self.constructor = constructor + self._hard_delete_callback = hard_delete_callback self.refcount = 0 self.lock = threading.Lock() if initialize_eagerly: @@ -169,6 +171,8 @@ def unsafe_hard_delete(self): if self.initialied: del self.obj self.initialied = False + if self._hard_delete_callback: + self._hard_delete_callback() class _SingletonManager: @@ -180,9 +184,15 @@ def __init__(self): def set_hard_delete_callback(self, callback): self._hard_delete_callback = callback - def register_singleton(self, constructor, tag, initialize_eagerly=True): + def register_singleton( + self, + constructor, + tag, + initialize_eagerly=True, + hard_delete_callback=None): assert tag not in self.entries, tag - self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly) + self.entries[tag] = _SingletonEntry( + constructor, initialize_eagerly, hard_delete_callback) def has_singleton(self, tag): return tag in self.entries From d8195045fd27e727bdfd70482443e413bc24af2e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 01:56:16 +0000 Subject: [PATCH 44/48] Add tests for reaping --- .../apache_beam/utils/multi_process_shared.py | 9 +- .../utils/multi_process_shared_test.py | 94 ++++++++++++++++++- 2 files changed, 101 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index cd79701f2478..a943f0fe1a6b 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -314,7 +314,10 @@ def _monitor_parent(): _process_level_singleton_manager.set_hard_delete_callback( handle_unsafe_hard_delete) _process_level_singleton_manager.register_singleton( - constructor, tag, initialize_eagerly=True) + constructor, + tag, + initialize_eagerly=True, + hard_delete_callback=handle_unsafe_hard_delete) server = serving_manager.get_server() logging.info( @@ -436,6 +439,10 @@ def acquire(self): # Caveat: They must always agree, as they will be ignored if the object # is already constructed. singleton = self._get_manager().acquire_singleton(self._tag) + # Trigger a sweep of zombie processes. + # calling active_children() has the side-effect of joining any finished + # processes, effectively reaping zombies from previous unsafe_hard_deletes. + if self._spawn_process: multiprocessing.active_children() return _AutoProxyWrapper(singleton) def release(self, obj): diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 4ccf19321243..f3258cf0a968 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -298,7 +298,8 @@ def setUp(self): 'main', 'to_delete', 'mix1', - 'mix2']: + 'mix2' + 'test_process_exit']: for ext in ['', '.address', '.address.error']: try: os.remove(os.path.join(tempdir, tag + ext)) @@ -368,6 +369,97 @@ def test_mix_usage(self): self.assertEqual(shared2.get(), 0) self.assertEqual(shared2.increment(), 1) + def test_process_exits_on_unsafe_hard_delete(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='test_process_exit', always_proxy=True, spawn_process=True) + obj = shared.acquire() + + self.assertEqual(obj.increment(), 1) + + children = multiprocessing.active_children() + server_process = None + for p in children: + if p.pid != os.getpid() and p.is_alive(): + server_process = p + break + + self.assertIsNotNone( + server_process, "Could not find spawned server process") + obj.unsafe_hard_delete() + server_process.join(timeout=5) + + self.assertFalse( + server_process.is_alive(), + f"Server process {server_process.pid} is still alive after hard delete") + self.assertIsNotNone( + server_process.exitcode, "Process has no exit code (did not exit)") + + with self.assertRaises(Exception): + obj.get() + + def test_process_exits_on_unsafe_hard_delete_with_manager(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='test_process_exit', always_proxy=True, spawn_process=True) + obj = shared.acquire() + + self.assertEqual(obj.increment(), 1) + + children = multiprocessing.active_children() + server_process = None + for p in children: + if p.pid != os.getpid() and p.is_alive(): + server_process = p + break + + self.assertIsNotNone( + server_process, "Could not find spawned server process") + shared.unsafe_hard_delete() + server_process.join(timeout=5) + + self.assertFalse( + server_process.is_alive(), + f"Server process {server_process.pid} is still alive after hard delete") + self.assertIsNotNone( + server_process.exitcode, "Process has no exit code (did not exit)") + + with self.assertRaises(Exception): + obj.get() + + def test_zombie_reaping_on_acquire(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True) + obj = shared1.acquire() + + children = multiprocessing.active_children() + server_pid = next( + p.pid for p in children if p.is_alive() and p.pid != os.getpid()) + + obj.unsafe_hard_delete() + + try: + os.kill(server_pid, 0) + is_zombie = True + except OSError: + is_zombie = False + self.assertTrue( + is_zombie, + f"Server process {server_pid} was reaped too early before acquire()") + + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True) + _ = shared2.acquire() + + pid_exists = True + try: + os.kill(server_pid, 0) + except OSError: + pid_exists = False + + self.assertFalse( + pid_exists, + f"Old server process {server_pid} was not reaped by acquire() sweep") + shared2.unsafe_hard_delete() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From 075ab415197f2e90c42e8b8713b5798efe0e39fc Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 01:58:09 +0000 Subject: [PATCH 45/48] Remove redundant line --- sdks/python/apache_beam/utils/multi_process_shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index a943f0fe1a6b..1b3fd50eb107 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -270,7 +270,6 @@ def _run_server_process(address_file, tag, constructor, authkey): Includes a 'Suicide Pact' monitor: If parent dies, I die. """ parent_pid = os.getppid() - os.setsid() def cleanup_files(): logging.info("Server process exiting. Deleting files for %s", tag) From c627783c59b0512e374047626fcf3cf4d2679717 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 04:06:03 +0000 Subject: [PATCH 46/48] Add uuid to make sure eviction clears model --- .../apache_beam/ml/inference/model_manager.py | 42 ++++++++++++++++--- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 30dbe08304f3..f5d18d355da7 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -24,6 +24,7 @@ usage and performance. """ +import uuid import time import threading import subprocess @@ -249,6 +250,34 @@ def _solve(self): logger.error("Solver failed: %s", e) +class TrackedModelProxy: + def __init__(self, obj): + object.__setattr__(self, "_wrapped_obj", obj) + object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) + + def __getattr__(self, name): + return getattr(self._wrapped_obj, name) + + def __setattr__(self, name, value): + setattr(self._wrapped_obj, name, value) + + def __call__(self, *args, **kwargs): + return self._wrapped_obj(*args, **kwargs) + + def __str__(self): + return str(self._wrapped_obj) + + def __repr__(self): + return repr(self._wrapped_obj) + + def __dir__(self): + return dir(self._wrapped_obj) + + def unsafe_hard_delete(self): + if hasattr(self._wrapped_obj, "unsafe_hard_delete"): + self._wrapped_obj.unsafe_hard_delete() + + class ModelManager: _lock = threading.Lock() @@ -546,12 +575,13 @@ def _perform_eviction(self, key, tag, instance, score): if key in self._idle_lru: del self._idle_lru[key] - if instance in self._models[tag]: - self._models[tag].remove(instance) - - if hasattr(instance, "unsafe_hard_delete"): - instance.unsafe_hard_delete() + target_id = instance._beam_tracking_id + for i, inst in enumerate(self._models[tag]): + if inst._beam_tracking_id == target_id: + del self._models[tag][i] + break + instance.unsafe_hard_delete() del instance gc.collect() torch.cuda.empty_cache() @@ -565,7 +595,7 @@ def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): with self._load_lock: logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) isolation_baseline_snap, _, _ = self._monitor.get_stats() - instance = loader_func() + instance = TrackedModelProxy(loader_func()) _, peak_during_load, _ = self._monitor.get_stats() with self._cv: From ebb6ff5fc7e40c2be53a4e1b21cffed33727504e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 04:10:04 +0000 Subject: [PATCH 47/48] Avoid pickling issue --- sdks/python/apache_beam/ml/inference/model_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index f5d18d355da7..e56b8ee8d03f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -264,6 +264,12 @@ def __setattr__(self, name, value): def __call__(self, *args, **kwargs): return self._wrapped_obj(*args, **kwargs) + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + def __str__(self): return str(self._wrapped_obj) From 1d3f8c93b30704ee47704fa0621a52921f2372f2 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 19:35:21 +0000 Subject: [PATCH 48/48] Wait on server start and log if taking too long --- sdks/python/apache_beam/utils/multi_process_shared.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index 1b3fd50eb107..0efa01f45570 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -502,6 +502,7 @@ def cleanup_process(): atexit.register(cleanup_process) start_time = time.time() + last_log = start_time while True: if os.path.exists(address_file): break @@ -523,10 +524,12 @@ def cleanup_process(): "Shared Server Process died unexpectedly" f" with exit code {exit_code}") - if time.time() - start_time > 30: - if p.is_alive(): p.terminate() - raise TimeoutError( - f"Timed out waiting for server process {self._tag} to start.") + if time.time() - last_log > 300: + logging.warning( + "Still waiting for %s to initialize... %ss elapsed)", + self._tag, + int(time.time() - start_time)) + last_log = time.time() time.sleep(0.05)