diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index d79565ee24da..14030f669b78 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -63,8 +63,10 @@ try: # pylint: disable=wrong-import-order, wrong-import-position import resource + from apache_beam.ml.inference.model_manager import ModelManager except ImportError: resource = None # type: ignore[assignment] + ModelManager = None # type: ignore[assignment] _NANOSECOND_TO_MILLISECOND = 1_000_000 _NANOSECOND_TO_MICROSECOND = 1_000 @@ -467,11 +469,12 @@ def request( raise NotImplementedError(type(self)) -class _ModelManager: +class _ModelHandlerManager: """ - A class for efficiently managing copies of multiple models. Will load a - single copy of each model into a multi_process_shared object and then - return a lookup key for that object. + A class for efficiently managing copies of multiple model handlers. + Will load a single copy of each model from the model handler into a + multi_process_shared object and then return a lookup key for that + object. Used for KeyedModelHandler only. """ def __init__(self, mh_map: dict[str, ModelHandler]): """ @@ -536,8 +539,9 @@ def load(self, key: str) -> _ModelLoadStats: def increment_max_models(self, increment: int): """ - Increments the number of models that this instance of a _ModelManager is - able to hold. If it is never called, no limit is imposed. + Increments the number of models that this instance of a + _ModelHandlerManager is able to hold. If it is never called, + no limit is imposed. Args: increment: the amount by which we are incrementing the number of models. """ @@ -590,7 +594,7 @@ def __init__( class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], ModelHandler[tuple[KeyT, ExampleT], tuple[KeyT, PredictionT], - Union[ModelT, _ModelManager]]): + Union[ModelT, _ModelHandlerManager]]): def __init__( self, unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT], @@ -743,15 +747,15 @@ def __init__( 'to exactly one model handler.') self._key_to_id_map[key] = keys[0] - def load_model(self) -> Union[ModelT, _ModelManager]: + def load_model(self) -> Union[ModelT, _ModelHandlerManager]: if self._single_model: return self._unkeyed.load_model() - return _ModelManager(self._id_to_mh_map) + return _ModelHandlerManager(self._id_to_mh_map) def run_inference( self, batch: Sequence[tuple[KeyT, ExampleT]], - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], inference_args: Optional[dict[str, Any]] = None ) -> Iterable[tuple[KeyT, PredictionT]]: if self._single_model: @@ -853,7 +857,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): def update_model_paths( self, - model: Union[ModelT, _ModelManager], + model: Union[ModelT, _ModelHandlerManager], model_paths: list[KeyModelPathMapping[KeyT]] = None): # When there are many models, the keyed model handler is responsible for # reorganizing the model handlers into cohorts and telling the model @@ -1272,6 +1276,8 @@ def __init__( model_metadata_pcoll: beam.PCollection[ModelMetadata] = None, watch_model_pattern: Optional[str] = None, model_identifier: Optional[str] = None, + use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None, **kwargs): """ A transform that takes a PCollection of examples (or features) for use @@ -1312,6 +1318,8 @@ def __init__( self._exception_handling_timeout = None self._timeout = None self._watch_model_pattern = watch_model_pattern + self._use_model_manager = use_model_manager + self._model_manager_args = model_manager_args self._kwargs = kwargs # Generate a random tag to use for shared.py and multi_process_shared.py to # allow us to effectively disambiguate in multi-model settings. Only use @@ -1424,7 +1432,9 @@ def expand( self._clock, self._metrics_namespace, load_model_at_runtime, - self._model_tag), + self._model_tag, + self._use_model_manager, + self._model_manager_args), self._inference_args, beam.pvalue.AsSingleton( self._model_metadata_pcoll, @@ -1737,21 +1747,47 @@ def load_model_status( return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag) +class _ProxyLoader: + """ + A helper callable to wrap the loader for MultiProcessShared. + """ + def __init__(self, loader_func, model_tag): + self.loader_func = loader_func + self.model_tag = model_tag + + def __call__(self): + unique_tag = self.model_tag + '_' + uuid.uuid4().hex + # Ensure that each model loaded in a different process for parallelism + return multi_process_shared.MultiProcessShared( + self.loader_func, tag=unique_tag, always_proxy=True, + spawn_process=True).acquire() + + class _SharedModelWrapper(): """A router class to map incoming calls to the correct model. This allows us to round robin calls to models sitting in different processes so that we can more efficiently use resources (e.g. GPUs). """ - def __init__(self, models: list[Any], model_tag: str): + def __init__( + self, + models: Union[list[Any], ModelManager], + model_tag: str, + loader_func: Callable[[], Any] = None): self.models = models - if len(models) > 1: + self.use_model_manager = not isinstance(models, list) + self.model_tag = model_tag + self.loader_func = loader_func + if not self.use_model_manager and len(models) > 1: self.model_router = multi_process_shared.MultiProcessShared( lambda: _ModelRoutingStrategy(), tag=f'{model_tag}_counter', always_proxy=True).acquire() def next_model(self): + if self.use_model_manager: + loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag) + return self.models.acquire_model(self.model_tag, loader_wrapper) if len(self.models) == 1: # Short circuit if there's no routing strategy needed in order to # avoid the cross-process call @@ -1759,9 +1795,19 @@ def next_model(self): return self.models[self.model_router.next_model_index(len(self.models))] + def release_model(self, model_tag: str, model: Any): + if self.use_model_manager: + self.models.release_model(model_tag, model) + def all_models(self): + if self.use_model_manager: + return self.models.all_models()[self.model_tag] return self.models + def force_reset(self): + if self.use_model_manager: + self.models.force_reset() + class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]): def __init__( @@ -1770,7 +1816,9 @@ def __init__( clock, metrics_namespace, load_model_at_runtime: bool = False, - model_tag: str = "RunInference"): + model_tag: str = "RunInference", + use_model_manager: bool = False, + model_manager_args: Optional[dict[str, Any]] = None): """A DoFn implementation generic to frameworks. Args: @@ -1794,6 +1842,8 @@ def __init__( # _cur_tag is the tag of the actually loaded model self._model_tag = model_tag self._cur_tag = model_tag + self.use_model_manager = use_model_manager + self._model_manager_args = model_manager_args or {} def _load_model( self, @@ -1828,7 +1878,15 @@ def load(): model_tag = side_input_model_path # Ensure the tag we're loading is valid, if not replace it with a valid tag self._cur_tag = self._model_metadata.get_valid_tag(model_tag) - if self._model_handler.share_model_across_processes(): + if self.use_model_manager: + logging.info("Using Model Manager to manage models automatically.") + model_manager = multi_process_shared.MultiProcessShared( + lambda: ModelManager(**self._model_manager_args), + tag='model_manager', + always_proxy=True).acquire() + model_wrapper = _SharedModelWrapper( + model_manager, self._cur_tag, self._model_handler.load_model) + elif self._model_handler.share_model_across_processes(): models = [] for copy_tag in _get_tags_for_copies(self._cur_tag, self._model_handler.model_copies()): @@ -1885,6 +1943,8 @@ def _run_inference(self, batch, inference_args): model = self._model.next_model() result_generator = self._model_handler.run_inference( batch, model, inference_args) + if self.use_model_manager: + self._model.release_model(self._model_tag, model) except BaseException as e: if self._metrics_collector: self._metrics_collector.failed_batches_counter.inc() diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 574e71de89ce..b0f52e16d181 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -117,6 +117,19 @@ def batch_elements_kwargs(self): return {'min_batch_size': 1, 'max_batch_size': 1} +class SimpleFakeModelHanlder(base.ModelHandler[int, int, FakeModel]): + def load_model(self): + return FakeModel() + + def run_inference( + self, + batch: Sequence[int], + model: FakeModel, + inference_args=None) -> Iterable[int]: + for example in batch: + yield model.predict(example) + + class FakeModelHandler(base.ModelHandler[int, int, FakeModel]): def __init__( self, @@ -310,6 +323,15 @@ def validate_inference_args(self, inference_args): pass +def try_import_model_manager(): + try: + # pylint: disable=unused-import + from apache_beam.ml.inference.model_manager import ModelManager + return True + except ImportError: + return False + + class RunInferenceBaseTest(unittest.TestCase): def test_run_inference_impl_simple_examples(self): with TestPipeline() as pipeline: @@ -1599,13 +1621,13 @@ def test_child_class_without_env_vars(self): actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars()) assert_that(actual, equal_to(expected), label='assert:inferences') - def test_model_manager_loads_shared_model(self): + def test_model_handler_manager_loads_shared_model(self): mhs = { 'key1': FakeModelHandler(state=1), 'key2': FakeModelHandler(state=2), 'key3': FakeModelHandler(state=3) } - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag # Use bad_mh's load function to make sure we're actually loading the # version already stored @@ -1623,12 +1645,12 @@ def test_model_manager_loads_shared_model(self): self.assertEqual(2, model2.predict(10)) self.assertEqual(3, model3.predict(10)) - def test_model_manager_evicts_models(self): + def test_model_handler_manager_evicts_models(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(2) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) @@ -1667,10 +1689,10 @@ def test_model_manager_evicts_models(self): mh3.load_model, tag=tag3).acquire() self.assertEqual(8, model3.predict(10)) - def test_model_manager_evicts_models_after_update(self): + def test_model_handler_manager_evicts_models_after_update(self): mh1 = FakeModelHandler(state=1) mhs = {'key1': mh1} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) tag1 = mm.load('key1').model_tag sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1) model1 = sh1.acquire() @@ -1697,13 +1719,12 @@ def test_model_manager_evicts_models_after_update(self): self.assertEqual(6, model1.predict(10)) sh1.release(model1) - def test_model_manager_evicts_correct_num_of_models_after_being_incremented( - self): + def test_model_handler_manager_evicts_correctly_after_being_incremented(self): mh1 = FakeModelHandler(state=1) mh2 = FakeModelHandler(state=2) mh3 = FakeModelHandler(state=3) mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3} - mm = base._ModelManager(mh_map=mhs) + mm = base._ModelHandlerManager(mh_map=mhs) mm.increment_max_models(1) mm.increment_max_models(1) tag1 = mm.load('key1').model_tag @@ -1888,6 +1909,36 @@ def test_model_status_provides_valid_garbage_collection(self): self.assertEqual(0, len(tags)) + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') + def test_run_inference_impl_with_model_manager(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + SimpleFakeModelHanlder(), use_model_manager=True) + assert_that(actual, equal_to(expected), label='assert:inferences') + + @unittest.skipIf( + not try_import_model_manager(), 'Model Manager not available') + def test_run_inference_impl_with_model_manager_args(self): + with TestPipeline() as pipeline: + examples = [1, 5, 3, 10] + expected = [example + 1 for example in examples] + pcoll = pipeline | 'start' >> beam.Create(examples) + actual = pcoll | base.RunInference( + SimpleFakeModelHanlder(), + use_model_manager=True, + model_manager_args={ + 'slack_percentage': 0.2, + 'poll_interval': 1.0, + 'peak_window_seconds': 10.0, + 'min_data_points': 10, + 'smoothing_factor': 0.5 + }) + assert_that(actual, equal_to(expected), label='assert:inferences') + def _always_retry(e: Exception) -> bool: return True diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py new file mode 100644 index 000000000000..e56b8ee8d03f --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -0,0 +1,669 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Module for managing ML models in Apache Beam pipelines. + +This module provides classes and functions to efficiently manage multiple +machine learning models within Apache Beam pipelines. It includes functionality +for loading, caching, and updating models using multi-process shared memory, +ensuring that models are reused across different workers to optimize resource +usage and performance. +""" + +import uuid +import time +import threading +import subprocess +import logging +import gc +import numpy as np +from scipy.optimize import nnls +import torch +import heapq +import itertools +from collections import defaultdict, deque, Counter, OrderedDict +from typing import Dict, Any, Tuple, Optional, Callable + +logger = logging.getLogger(__name__) + + +class GPUMonitor: + def __init__( + self, + fallback_memory_mb: float = 16000.0, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0): + self._current_usage = 0.0 + self._peak_usage = 0.0 + self._total_memory = fallback_memory_mb + self._poll_interval = poll_interval + self._peak_window_seconds = peak_window_seconds + self._memory_history = deque() + self._running = False + self._thread = None + self._lock = threading.Lock() + self._gpu_available = self._detect_hardware() + + def _detect_hardware(self): + try: + cmd = [ + "nvidia-smi", + "--query-gpu=memory.total", + "--format=csv,noheader,nounits" + ] + output = subprocess.check_output(cmd, text=True).strip() + self._total_memory = float(output) + return True + except (FileNotFoundError, subprocess.CalledProcessError): + logger.warning( + "nvidia-smi not found or failed. Defaulting total memory to %s MB", + self._total_memory) + return False + except Exception as e: + logger.warning( + "Error parsing nvidia-smi output: %s. " + "Defaulting total memory to %s MB", + e, + self._total_memory) + return False + + def start(self): + if self._running or not self._gpu_available: + return + self._running = True + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self): + self._running = False + if self._thread: + self._thread.join() + + def reset_peak(self): + with self._lock: + now = time.time() + self._memory_history.clear() + self._memory_history.append((now, self._current_usage)) + self._peak_usage = self._current_usage + + def get_stats(self) -> Tuple[float, float, float]: + with self._lock: + return self._current_usage, self._peak_usage, self._total_memory + + def refresh(self): + """Forces an immediate poll of the GPU.""" + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + # Recalculate peak immediately + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + + def _get_nvidia_smi_used(self) -> float: + try: + cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" + output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + free_memory = float(output) + return self._total_memory - free_memory + except Exception: + return 0.0 + + def _poll_loop(self): + while self._running: + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + time.sleep(self._poll_interval) + + +class ResourceEstimator: + def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): + self.smoothing_factor = smoothing_factor + self.min_data_points = min_data_points + self.estimates: Dict[str, float] = {} + self.history = defaultdict(lambda: deque(maxlen=20)) + self.known_models = set() + self._lock = threading.Lock() + + def is_unknown(self, model_tag: str) -> bool: + with self._lock: + return model_tag not in self.estimates + + def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float: + with self._lock: + return self.estimates.get(model_tag, default_mb) + + def set_initial_estimate(self, model_tag: str, cost: float): + with self._lock: + self.estimates[model_tag] = cost + self.known_models.add(model_tag) + logger.info("Initial Profile for %s: %s MB", model_tag, cost) + + def add_observation( + self, active_snapshot: Dict[str, int], peak_memory: float): + logger.info( + "Adding Observation: Snapshot=%s, PeakMemory=%.1f MB", + active_snapshot, + peak_memory) + if not active_snapshot: + return + with self._lock: + config_key = tuple(sorted(active_snapshot.items())) + self.history[config_key].append(peak_memory) + for tag in active_snapshot: + self.known_models.add(tag) + self._solve() + + def _solve(self): + """ + Solves Ax=b using raw readings (no pre-averaging) and NNLS. + This creates a 'tall' matrix A where every memory reading is + a separate equation. + """ + unique = sorted(list(self.known_models)) + + # We need to build the matrix first to know if we have enough data points + A, b = [], [] + + for config_key, mem_values in self.history.items(): + if not mem_values: + continue + + # 1. Create the feature row for this configuration ONCE + # (It represents the model counts + bias) + counts = dict(config_key) + feature_row = [counts.get(model, 0) for model in unique] + feature_row.append(1) # Bias column + + # 2. Add a separate row to the matrix for EVERY individual reading + # Instead of averaging, we flatten the history into the matrix + for reading in mem_values: + A.append(feature_row) # The inputs (models) stay the same + b.append(reading) # The output (memory) varies due to noise + + # Convert to numpy for SciPy + A = np.array(A) + b = np.array(b) + + if len( + self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points: + # Not enough data to solve yet + return + + logger.info( + "Solving with %s total observations for %s models.", + len(A), + len(unique)) + + try: + # Solve using Non-Negative Least Squares + # x will be >= 0 + x, _ = nnls(A, b) + + weights = x[:-1] + bias = x[-1] + + for i, model in enumerate(unique): + calculated_cost = weights[i] + + if model in self.estimates: + old = self.estimates[model] + new = (old * (1 - self.smoothing_factor)) + ( + calculated_cost * self.smoothing_factor) + self.estimates[model] = new + else: + self.estimates[model] = calculated_cost + + logger.info( + "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) + logger.info("System Bias: %s MB", bias) + + except Exception as e: + logger.error("Solver failed: %s", e) + + +class TrackedModelProxy: + def __init__(self, obj): + object.__setattr__(self, "_wrapped_obj", obj) + object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) + + def __getattr__(self, name): + return getattr(self._wrapped_obj, name) + + def __setattr__(self, name, value): + setattr(self._wrapped_obj, name, value) + + def __call__(self, *args, **kwargs): + return self._wrapped_obj(*args, **kwargs) + + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + + def __str__(self): + return str(self._wrapped_obj) + + def __repr__(self): + return repr(self._wrapped_obj) + + def __dir__(self): + return dir(self._wrapped_obj) + + def unsafe_hard_delete(self): + if hasattr(self._wrapped_obj, "unsafe_hard_delete"): + self._wrapped_obj.unsafe_hard_delete() + + +class ModelManager: + _lock = threading.Lock() + + def __init__( + self, + monitor: Optional['GPUMonitor'] = None, + slack_percentage: float = 0.10, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0, + min_data_points: int = 5, + smoothing_factor: float = 0.2, + eviction_cooldown_seconds: float = 10.0, + min_model_copies: int = 1): + + self._estimator = ResourceEstimator( + min_data_points=min_data_points, smoothing_factor=smoothing_factor) + self._monitor = monitor if monitor else GPUMonitor( + poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) + self._slack_percentage = slack_percentage + + self._eviction_cooldown = eviction_cooldown_seconds + self._min_model_copies = min_model_copies + + # Resource State + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + self._wait_queue = [] + self._ticket_counter = itertools.count() + self._cv = threading.Condition() + self._load_lock = threading.Lock() + + self._monitor.start() + + def all_models(self, tag) -> list[Any]: + return self._models[tag] + + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: + current_priority = 0 if self._estimator.is_unknown(tag) else 1 + ticket_num = next(self._ticket_counter) + my_id = object() + + with self._cv: + # FAST PATH + if self._pending_isolation_count == 0 and not self._isolation_mode: + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + # SLOW PATH + logger.info( + "Acquire Queued: tag=%s, priority=%d " + "total models count=%s ticket num=%s", + tag, + current_priority, + len(self._models[tag]), + ticket_num) + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + + should_spawn = False + est_cost = 0.0 + is_unknown = False + + try: + while True: + if not self._wait_queue or self._wait_queue[0][2] is not my_id: + logger.info( + "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) + self._cv.wait() + continue + + real_is_unknown = self._estimator.is_unknown(tag) + real_priority = 0 if real_is_unknown else 1 + + if current_priority != real_priority: + heapq.heappop(self._wait_queue) + current_priority = real_priority + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + self._cv.notify_all() + continue + + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + is_unknown = real_is_unknown + + # Path A: Isolation + if is_unknown: + if self._total_active_jobs > 0: + logger.info( + "Waiting to enter isolation: tag=%s ticket num=%s", + tag, + ticket_num) + self._cv.wait() + continue + + logger.info("Unknown model %s detected. Flushing GPU.", tag) + self._delete_all_models() + + self._isolation_mode = True + self._total_active_jobs += 1 + self._isolation_baseline, _, _ = self._monitor.get_stats() + self._monitor.reset_peak() + should_spawn = True + break + + # Path B: Concurrent + else: + if self._pending_isolation_count > 0 or self._isolation_mode: + logger.info( + "Waiting due to isolation in progress: tag=%s ticket num%s", + tag, + ticket_num) + self._cv.wait() + continue + + curr, _, total = self._monitor.get_stats() + est_cost = self._estimator.get_estimate(tag) + limit = total * (1 - self._slack_percentage) + + # Use current usage for capacity check (ignore old spikes) + if (curr + self._pending_reservations + est_cost) <= limit: + self._pending_reservations += est_cost + self._total_active_jobs += 1 + self._active_counts[tag] += 1 + should_spawn = True + break + + # Evict to make space (passing tag to check demand/existence) + if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + continue + + idle_count = 0 + other_idle_count = 0 + for item in self._idle_lru.items(): + if item[1][0] == tag: + idle_count += 1 + else: + other_idle_count += 1 + total_model_count = 0 + for _, instances in self._models.items(): + total_model_count += len(instances) + curr, _, _ = self._monitor.get_stats() + logger.info( + "Waiting for resources to free up: " + "tag=%s ticket num%s model count=%s " + "idle count=%s resource usage=%.1f MB " + "total models count=%s other idle=%s", + tag, + ticket_num, + len(self._models[tag]), + idle_count, + curr, + total_model_count, + other_idle_count) + self._cv.wait(timeout=10.0) + + finally: + if self._wait_queue and self._wait_queue[0][2] is my_id: + heapq.heappop(self._wait_queue) + else: + for i, item in enumerate(self._wait_queue): + if item[2] is my_id: + self._wait_queue.pop(i) + heapq.heapify(self._wait_queue) + self._cv.notify_all() + + if should_spawn: + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) + + def release_model(self, tag: str, instance: Any): + with self._cv: + try: + self._total_active_jobs -= 1 + if self._active_counts[tag] > 0: + self._active_counts[tag] -= 1 + + self._idle_lru[id(instance)] = (tag, instance, time.time()) + + _, peak_during_job, _ = self._monitor.get_stats() + + if self._isolation_mode and self._active_counts[tag] == 0: + cost = max(0, peak_during_job - self._isolation_baseline) + self._estimator.set_initial_estimate(tag, cost) + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + snapshot = { + t: len(instances) + for t, instances in self._models.items() if len(instances) > 0 + } + if snapshot: + self._estimator.add_observation(snapshot, peak_during_job) + + finally: + self._cv.notify_all() + + def _try_grab_from_lru(self, tag: str) -> Any: + target_key = None + target_instance = None + + for key, (t, instance, _) in reversed(self._idle_lru.items()): + if t == tag: + target_key = key + target_instance = instance + break + + if target_instance: + del self._idle_lru[target_key] + self._active_counts[tag] += 1 + self._total_active_jobs += 1 + return target_instance + + logger.info("No idle model found for tag: %s", tag) + return None + + def _evict_to_make_space( + self, limit: float, est_cost: float, requesting_tag: str) -> bool: + """ + Evicts models based on Demand Magnitude + Tiers. + Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction + of the lowest-demand candidate to avoid starvation. + """ + evicted_something = False + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + if projected_usage <= limit: + return False + + now = time.time() + + demand_map = Counter() + for item in self._wait_queue: + if len(item) >= 4: + demand_map[item[3]] += 1 + + my_demand = demand_map[requesting_tag] + am_i_starving = len(self._models[requesting_tag]) == 0 + + candidates = [] + for key, (tag, instance, release_time) in self._idle_lru.items(): + candidate_demand = demand_map[tag] + + if not am_i_starving: + if candidate_demand >= my_demand: + continue + + age = now - release_time + is_cold = age >= self._eviction_cooldown + + total_copies = len(self._models[tag]) + is_surplus = total_copies > self._min_model_copies + + if is_cold and is_surplus: tier = 0 + elif not is_cold and is_surplus: tier = 1 + elif is_cold and not is_surplus: tier = 2 + else: tier = 3 + + score = (candidate_demand * 10) + tier + + candidates.append((score, release_time, key, tag, instance)) + + candidates.sort(key=lambda x: (x[0], x[1])) + + for score, _, key, tag, instance in candidates: + if projected_usage <= limit: + break + + if key not in self._idle_lru: continue + + self._perform_eviction(key, tag, instance, score) + evicted_something = True + + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + return evicted_something + + def _perform_eviction(self, key, tag, instance, score): + logger.info("Evicting Model: %s (Score %d)", tag, score) + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage Before Eviction: %.1f MB", curr) + + if key in self._idle_lru: + del self._idle_lru[key] + + target_id = instance._beam_tracking_id + for i, inst in enumerate(self._models[tag]): + if inst._beam_tracking_id == target_id: + del self._models[tag][i] + break + + instance.unsafe_hard_delete() + del instance + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage After Eviction: %.1f MB", curr) + + def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): + try: + with self._load_lock: + logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) + isolation_baseline_snap, _, _ = self._monitor.get_stats() + instance = TrackedModelProxy(loader_func()) + _, peak_during_load, _ = self._monitor.get_stats() + + with self._cv: + snapshot = {tag: 1} + self._estimator.add_observation( + snapshot, peak_during_load - isolation_baseline_snap) + + if not is_unknown: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._models[tag].append(instance) + return instance + + except Exception as e: + logger.error("Load Failed: %s. Error: %s", tag, e) + with self._cv: + self._total_active_jobs -= 1 + if is_unknown: + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._active_counts[tag] -= 1 + self._cv.notify_all() + raise e + + def _delete_all_models(self): + self._idle_lru.clear() + for _, instances in self._models.items(): + for instance in instances: + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + del instance + self._models.clear() + self._active_counts.clear() + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + + def _force_reset(self): + logger.warning("Force Reset Triggered") + self._delete_all_models() + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._wait_queue = [] + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + def shutdown(self): + self._delete_all_models() + gc.collect() + torch.cuda.empty_cache() + self._monitor.stop() + + def __del__(self): + self.shutdown() + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_it_test.py b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py new file mode 100644 index 000000000000..609de0606a14 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_it_test.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to + +# pylint: disable=ungrouped-imports +try: + import torch + from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler +except ImportError as e: + raise unittest.SkipTest( + "HuggingFace model handler dependencies are not installed") + + +class HuggingFaceGpuTest(unittest.TestCase): + + # Skips the test if you run it on a machine without a GPU + @unittest.skipIf( + not torch.cuda.is_available(), "No GPU detected, skipping GPU test") + def test_sentiment_analysis_on_gpu_large_input(self): + """ + Runs inference on a GPU (device=0) with a larger set of inputs. + """ + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="distilbert-base-uncased-finetuned-sst-2-english", + device=0, + inference_args={"batch_size": 4}) + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "The weather is okay, but I wish it were sunnier.", + "Apache Beam makes parallel processing incredibly efficient.", + "I am extremely disappointed with the service.", + "Logic and reason are the pillars of good debugging.", + "I'm so happy today!", + "This error message is confusing and unhelpful.", + "The movie was fantastic and the acting was superb.", + "I hate waiting in line for so long." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) + + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) + + expected_labels = [ + 'POSITIVE', # "love this product" + 'NEGATIVE', # "worst experience" + 'NEGATIVE', # "weather is okay, but..." + 'POSITIVE', # "incredibly efficient" + 'NEGATIVE', # "disappointed" + 'POSITIVE', # "pillars of good debugging" + 'POSITIVE', # "so happy" + 'NEGATIVE', # "confusing and unhelpful" + 'POSITIVE', # "fantastic" + 'NEGATIVE' # "hate waiting" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, equal_to(expected_labels), label='CheckPredictions') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_sentiment_analysis_large_roberta_gpu(self): + """ + Runs inference using a Large architecture (RoBERTa-Large, ~355M params). + This tests if the GPU can handle larger weights and requires more VRAM. + """ + + model_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="Siebert/sentiment-roberta-large-english", + device=0, + inference_args={"batch_size": 2}) + + DUPLICATE_FACTOR = 2 + + with TestPipeline() as pipeline: + examples = [ + "I absolutely love this product, it's a game changer!", + "This is the worst experience I have ever had.", + "Apache Beam scales effortlessly to massive datasets.", + "I am somewhat annoyed by the delay.", + "The nuanced performance of this large model is impressive.", + "I regret buying this immediately.", + "The sunset looks beautiful tonight.", + "This documentation is sparse and misleading.", + "Winning the championship felt surreal.", + "I'm feeling very neutral about this whole situation." + ] * DUPLICATE_FACTOR + + pcoll = pipeline | 'CreateInputs' >> beam.Create(examples) + predictions = pcoll | 'RunInference' >> RunInference( + model_handler, use_model_manager=True) + actual_labels = predictions | beam.Map(lambda x: x.inference['label']) + + expected_labels = [ + 'POSITIVE', # love + 'NEGATIVE', # worst + 'POSITIVE', # scales effortlessly + 'NEGATIVE', # annoyed + 'POSITIVE', # impressive + 'NEGATIVE', # regret + 'POSITIVE', # beautiful + 'NEGATIVE', # misleading + 'POSITIVE', # surreal + 'NEGATIVE' # "neutral" + ] * DUPLICATE_FACTOR + + assert_that( + actual_labels, + equal_to(expected_labels), + label='CheckPredictionsLarge') + + @unittest.skipIf(not torch.cuda.is_available(), "No GPU detected") + def test_parallel_inference_branches(self): + """ + Tests a branching pipeline where one input source feeds two + RunInference transforms running in parallel. + + Topology: + [ Input Data ] + | + +--------+--------+ + | | + [ Translation ] [ Sentiment ] + """ + + translator_handler = HuggingFacePipelineModelHandler( + task="translation_en_to_es", + model="Helsinki-NLP/opus-mt-en-es", + device=0, + inference_args={"batch_size": 8}) + sentiment_handler = HuggingFacePipelineModelHandler( + task="sentiment-analysis", + model="nlptown/bert-base-multilingual-uncased-sentiment", + device=0, + inference_args={"batch_size": 8}) + base_examples = [ + "I love this product.", + "This is terrible.", + "Hello world.", + "The service was okay.", + "I am extremely angry." + ] + MULTIPLIER = 10 + examples = base_examples * MULTIPLIER + + with TestPipeline() as pipeline: + inputs = pipeline | 'CreateInputs' >> beam.Create(examples) + _ = ( + inputs + | 'RunTranslation' >> RunInference( + translator_handler, use_model_manager=True) + | 'ExtractSpanish' >> + beam.Map(lambda x: x.inference['translation_text'])) + _ = ( + inputs + | 'RunSentiment' >> RunInference( + sentiment_handler, use_model_manager=True) + | 'ExtractLabel' >> beam.Map(lambda x: x.inference['label'])) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py new file mode 100644 index 000000000000..7412ea6a6c64 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -0,0 +1,548 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import time +import threading +import random +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import patch + +try: + from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator +except ImportError as e: + raise unittest.SkipTest("Model Manager dependencies are not installed") + + +class MockGPUMonitor: + """ + Simulates GPU hardware with cumulative memory tracking. + Allows simulating specific allocation spikes and baseline usage. + """ + def __init__(self, total_memory=12000.0, peak_window: int = 5): + self._current = 0.0 + self._peak = 0.0 + self._total = total_memory + self._lock = threading.Lock() + self.running = False + self.history = [] + self.peak_window = peak_window + + def start(self): + self.running = True + + def stop(self): + self.running = False + + def get_stats(self): + with self._lock: + return self._current, self._peak, self._total + + def reset_peak(self): + with self._lock: + self._peak = self._current + self.history = [self._current] + + def set_usage(self, current_mb): + """Sets absolute usage (legacy helper).""" + with self._lock: + self._current = current_mb + self._peak = max(self._peak, current_mb) + + def allocate(self, amount_mb): + """Simulates memory allocation (e.g., tensors loaded to VRAM).""" + with self._lock: + self._current += amount_mb + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def free(self, amount_mb): + """Simulates memory freeing (not used often if pooling is active).""" + with self._lock: + self._current = max(0.0, self._current - amount_mb) + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def refresh(self): + """Simulates a refresh of the monitor stats (no-op for mock).""" + pass + + +class MockModel: + def __init__(self, name, size, monitor): + self.name = name + self.size = size + self.monitor = monitor + self.deleted = False + self.monitor.allocate(size) + + def unsafe_hard_delete(self): + if not self.deleted: + self.monitor.free(self.size) + self.deleted = True + + +class TestModelManager(unittest.TestCase): + def setUp(self): + """Force reset the Singleton ModelManager before every test.""" + ModelManager._instance = None + self.mock_monitor = MockGPUMonitor() + self.manager = ModelManager(monitor=self.mock_monitor) + + def tearDown(self): + self.manager.shutdown() + + def test_model_manager_capacity_check(self): + """ + Test that the manager blocks when spawning models exceeds the limit, + and unblocks when resources become available (via reuse). + """ + model_name = "known_model" + model_cost = 3000.0 + self.manager._estimator.set_initial_estimate(model_name, model_cost) + acquired_refs = [] + + def loader(): + self.mock_monitor.allocate(model_cost) + return model_name + + # 1. Saturate GPU with 3 models (9000 MB usage) + for _ in range(3): + inst = self.manager.acquire_model(model_name, loader) + acquired_refs.append(inst) + + # 2. Spawn one more (Should Block because 9000 + 3000 > Limit) + def run_inference(): + return self.manager.acquire_model(model_name, loader) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_inference) + try: + future.result(timeout=0.5) + self.fail("Should have blocked due to capacity") + except TimeoutError: + pass + + # 3. Release resources to unblock + item_to_release = acquired_refs.pop() + self.manager.release_model(model_name, item_to_release) + + result = future.result(timeout=2.0) + self.assertIsNotNone(result) + self.assertEqual(result, item_to_release) + + def test_model_manager_unknown_model_runs_isolated(self): + """Test that a model with no history runs in isolation.""" + model_name = "unknown_model_v1" + self.assertTrue(self.manager._estimator.is_unknown(model_name)) + + def dummy_loader(): + time.sleep(0.05) + return "model_instance" + + instance = self.manager.acquire_model(model_name, dummy_loader) + + self.assertTrue(self.manager._isolation_mode) + self.assertEqual(self.manager._total_active_jobs, 1) + + self.manager.release_model(model_name, instance) + self.assertFalse(self.manager._isolation_mode) + self.assertFalse(self.manager._estimator.is_unknown(model_name)) + + def test_model_manager_concurrent_execution(self): + """Test that multiple small known models can run together.""" + model_a = "small_model_a" + model_b = "small_model_b" + + self.manager._estimator.set_initial_estimate(model_a, 1000.0) + self.manager._estimator.set_initial_estimate(model_b, 1000.0) + self.mock_monitor.set_usage(1000.0) + + inst_a = self.manager.acquire_model(model_a, lambda: "A") + inst_b = self.manager.acquire_model(model_b, lambda: "B") + + self.assertEqual(self.manager._total_active_jobs, 2) + + self.manager.release_model(model_a, inst_a) + self.manager.release_model(model_b, inst_b) + self.assertEqual(self.manager._total_active_jobs, 0) + + def test_model_manager_concurrent_mixed_workload_convergence(self): + """ + Simulates a production environment with multiple model types running + concurrently. Verifies that the estimator converges. + """ + TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} + + def run_job(model_name): + cost = TRUE_COSTS[model_name] + + def loader(): + model = MockModel(model_name, cost, self.mock_monitor) + return model + + instance = self.manager.acquire_model(model_name, loader) + time.sleep(random.uniform(0.01, 0.05)) + self.manager.release_model(model_name, instance) + + # Create a workload stream + workload = ["model_small"] * 15 + ["model_medium"] * 15 + random.shuffle(workload) + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_job, name) for name in workload] + for f in futures: + f.result() + + est_small = self.manager._estimator.get_estimate("model_small") + est_med = self.manager._estimator.get_estimate("model_medium") + + self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) + self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) + + def test_model_manager_oom_recovery(self): + """Test that the manager recovers state if a loader crashes.""" + model_name = "crasher_model" + self.manager._estimator.set_initial_estimate(model_name, 1000.0) + + def crashing_loader(): + raise RuntimeError("CUDA OOM or similar") + + with self.assertRaises(RuntimeError): + self.manager.acquire_model(model_name, crashing_loader) + + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._cv._is_owned()) + + def test_model_managaer_force_reset_on_exception(self): + """Test that force_reset clears all models from the manager.""" + model_name = "test_model" + + def dummy_loader(): + self.mock_monitor.allocate(1000.0) + raise RuntimeError("Simulated loader exception") + + try: + instance = self.manager.acquire_model( + model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + instance = self.manager.acquire_model(model_name, dummy_loader) + except RuntimeError: + self.manager._force_reset() + self.assertTrue(len(self.manager._models[model_name]) == 0) + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._isolation_mode) + pass + + instance = self.manager.acquire_model(model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + + def test_single_model_convergence_with_fluctuations(self): + """ + Tests that the estimator converges to the true usage with fluctuations. + """ + model_name = "fluctuating_model" + model_cost = 3000.0 + load_cost = 2000.0 + + def loader(): + self.mock_monitor.allocate(load_cost) + return model_name + + model = self.manager.acquire_model(model_name, loader) + self.manager.release_model(model_name, model) + initial_est = self.manager._estimator.get_estimate(model_name) + self.assertEqual(initial_est, load_cost) + + def run_inference(): + model = self.manager.acquire_model(model_name, loader) + noise = model_cost - load_cost + random.uniform(-300.0, 300.0) + self.mock_monitor.allocate(noise) + time.sleep(0.1) + self.mock_monitor.free(noise) + self.manager.release_model(model_name, model) + return + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_inference) for _ in range(100)] + + for f in futures: + f.result() + + est_cost = self.manager._estimator.get_estimate(model_name) + self.assertAlmostEqual(est_cost, model_cost, delta=100.0) + + +class TestModelManagerEviction(unittest.TestCase): + def setUp(self): + self.mock_monitor = MockGPUMonitor(total_memory=12000.0) + ModelManager._instance = None + self.manager = ModelManager( + monitor=self.mock_monitor, + slack_percentage=0.0, + min_data_points=1, + eviction_cooldown_seconds=10.0, + min_model_copies=1) + + def tearDown(self): + self.manager.shutdown() + + def create_loader(self, name, size): + return lambda: MockModel(name, size, self.mock_monitor) + + def test_basic_lru_eviction(self): + self.manager._estimator.set_initial_estimate("A", 4000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 5000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a) + + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_a = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0) + + key_b = list(self.manager._idle_lru.keys())[1] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + model_a_again = self.manager.acquire_model( + "A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a_again) + + self.manager.acquire_model("C", self.create_loader("C", 5000)) + + self.assertEqual(len(self.manager.all_models("B")), 0) + self.assertEqual(len(self.manager.all_models("A")), 1) + + def test_chained_eviction(self): + self.manager._estimator.set_initial_estimate("big_guy", 8000) + models = [] + for i in range(4): + name = f"small_{i}" + m = self.manager.acquire_model(name, self.create_loader(name, 3000)) + self.manager.release_model(name, m) + models.append(m) + + self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000)) + + self.assertTrue(models[0].deleted) + self.assertTrue(models[1].deleted) + self.assertTrue(models[2].deleted) + self.assertFalse(models[3].deleted) + + def test_active_models_are_protected(self): + self.manager._estimator.set_initial_estimate("A", 6000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 4000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 6000)) + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_b = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + def acquire_c(): + return self.manager.acquire_model("C", self.create_loader("C", 4000)) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(acquire_c) + model_c = future.result(timeout=2.0) + + self.assertTrue(model_b.deleted) + self.assertFalse(model_a.deleted) + + self.manager.release_model("A", model_a) + self.manager.release_model("C", model_c) + + def test_unknown_model_clears_memory(self): + self.manager._estimator.set_initial_estimate("A", 2000) + model_a = self.manager.acquire_model("A", self.create_loader("A", 2000)) + self.manager.release_model("A", model_a) + self.assertFalse(model_a.deleted) + + self.assertTrue(self.manager._estimator.is_unknown("X")) + model_x = self.manager.acquire_model("X", self.create_loader("X", 10000)) + + self.assertTrue(model_a.deleted, "Model A should be deleted for isolation") + self.assertEqual(len(self.manager.all_models("A")), 0) + self.assertTrue(self.manager._isolation_mode) + self.manager.release_model("X", model_x) + + def test_concurrent_eviction_pressure(self): + def worker(idx): + name = f"model_{idx % 5}" + try: + m = self.manager.acquire_model(name, self.create_loader(name, 4000)) + time.sleep(0.001) + self.manager.release_model(name, m) + except Exception: + pass + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(worker, i) for i in range(50)] + for f in futures: + f.result() + + curr, _, _ = self.mock_monitor.get_stats() + expected_usage = 0 + for _, instances in self.manager._models.items(): + expected_usage += len(instances) * 4000 + + self.assertAlmostEqual(curr, expected_usage) + + def test_starvation_prevention_overrides_demand(self): + self.manager._estimator.set_initial_estimate("A", 12000) + m_a = self.manager.acquire_model("A", self.create_loader("A", 12000)) + self.manager.release_model("A", m_a) + + def cycle_a(): + try: + m = self.manager.acquire_model("A", self.create_loader("A", 12000)) + time.sleep(0.3) + self.manager.release_model("A", m) + except Exception: + pass + + executor = ThreadPoolExecutor(max_workers=5) + for _ in range(5): + executor.submit(cycle_a) + + def acquire_b(): + return self.manager.acquire_model("B", self.create_loader("B", 4000)) + + b_future = executor.submit(acquire_b) + model_b = b_future.result() + + self.assertTrue(m_a.deleted) + self.manager.release_model("B", model_b) + executor.shutdown(wait=True) + + +class TestGPUMonitor(unittest.TestCase): + def setUp(self): + self.subprocess_patcher = patch('subprocess.check_output') + self.mock_subprocess = self.subprocess_patcher.start() + + def tearDown(self): + self.subprocess_patcher.stop() + + def test_init_hardware_detected(self): + """Test that init correctly reads total memory when nvidia-smi exists.""" + self.mock_subprocess.return_value = "24576" + monitor = GPUMonitor() + self.assertTrue(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 24576.0) + + def test_init_hardware_missing(self): + """Test fallback behavior when nvidia-smi is missing.""" + self.mock_subprocess.side_effect = FileNotFoundError() + monitor = GPUMonitor(fallback_memory_mb=12000.0) + self.assertFalse(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 12000.0) + + @patch('time.sleep') + def test_polling_updates_stats(self, mock_sleep): + """Test that the polling loop updates current and peak usage.""" + def subprocess_side_effect(*args, **kwargs): + if isinstance(args[0], list) and "memory.total" in args[0][1]: + return "16000" + + if isinstance(args[0], str) and "memory.free" in args[0]: + return b"12000" + + raise Exception("Unexpected command") + + self.mock_subprocess.side_effect = subprocess_side_effect + self.mock_subprocess.return_value = None + + monitor = GPUMonitor() + monitor.start() + time.sleep(0.1) + curr, peak, total = monitor.get_stats() + monitor.stop() + + self.assertEqual(curr, 4000.0) + self.assertEqual(peak, 4000.0) + self.assertEqual(total, 16000.0) + + def test_reset_peak(self): + """Test that resetting peak usage works.""" + monitor = GPUMonitor() + monitor._gpu_available = True + + with monitor._lock: + monitor._current_usage = 2000.0 + monitor._peak_usage = 8000.0 + monitor._memory_history.append((time.time(), 8000.0)) + monitor._memory_history.append((time.time(), 2000.0)) + + monitor.reset_peak() + + _, peak, _ = monitor.get_stats() + self.assertEqual(peak, 2000.0) + + +class TestResourceEstimatorSolver(unittest.TestCase): + def setUp(self): + self.estimator = ResourceEstimator() + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_min_data_points(self, mock_nnls): + mock_nnls.return_value = ([100.0, 50.0], 0.0) + + self.estimator.add_observation({'model_A': 1}, 500) + self.estimator.add_observation({'model_B': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_B': 1}, 500) + self.assertTrue(mock_nnls.called) + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_unique_model_constraint(self, mock_nnls): + mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0) + + for _ in range(5): + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 800) + + for _ in range(5): + self.estimator.add_observation({'model_C': 1}, 400) + + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 300) + self.estimator.add_observation({'model_B': 1}, 300) + + self.assertTrue(mock_nnls.called) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py index aecb1284a1d4..0efa01f45570 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared.py +++ b/sdks/python/apache_beam/utils/multi_process_shared.py @@ -25,6 +25,10 @@ import logging import multiprocessing.managers import os +import time +import traceback +import atexit +import sys import tempfile import threading from typing import Any @@ -79,6 +83,10 @@ def singletonProxy_release(self): assert self._SingletonProxy_valid self._SingletonProxy_valid = False + def unsafe_hard_delete(self): + assert self._SingletonProxy_valid + self._SingletonProxy_entry.unsafe_hard_delete() + def __getattr__(self, name): if not self._SingletonProxy_valid: raise RuntimeError('Entry was released.') @@ -105,17 +113,39 @@ def __dir__(self): dir = self._SingletonProxy_entry.obj.__dir__() dir.append('singletonProxy_call__') dir.append('singletonProxy_release') + dir.append('unsafe_hard_delete') return dir +def _run_with_oom_protection(func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + # Check string to avoid hard import dependency + if 'CUDA out of memory' in str(e): + logging.warning("Caught CUDA OOM during operation. Cleaning memory.") + try: + import gc + import torch + gc.collect() + torch.cuda.empty_cache() + except ImportError: + pass + except Exception as cleanup_error: + logging.error("Failed to clean up CUDA memory: %s", cleanup_error) + raise e + + class _SingletonEntry: """Represents a single, refcounted entry in this process.""" - def __init__(self, constructor, initialize_eagerly=True): + def __init__( + self, constructor, initialize_eagerly=True, hard_delete_callback=None): self.constructor = constructor + self._hard_delete_callback = hard_delete_callback self.refcount = 0 self.lock = threading.Lock() if initialize_eagerly: - self.obj = constructor() + self.obj = _run_with_oom_protection(constructor) self.initialied = True else: self.initialied = False @@ -123,7 +153,7 @@ def __init__(self, constructor, initialize_eagerly=True): def acquire(self): with self.lock: if not self.initialied: - self.obj = self.constructor() + self.obj = _run_with_oom_protection(self.constructor) self.initialied = True self.refcount += 1 return _SingletonProxy(self) @@ -141,14 +171,28 @@ def unsafe_hard_delete(self): if self.initialied: del self.obj self.initialied = False + if self._hard_delete_callback: + self._hard_delete_callback() class _SingletonManager: entries: Dict[Any, Any] = {} - def register_singleton(self, constructor, tag, initialize_eagerly=True): + def __init__(self): + self._hard_delete_callback = None + + def set_hard_delete_callback(self, callback): + self._hard_delete_callback = callback + + def register_singleton( + self, + constructor, + tag, + initialize_eagerly=True, + hard_delete_callback=None): assert tag not in self.entries, tag - self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly) + self.entries[tag] = _SingletonEntry( + constructor, initialize_eagerly, hard_delete_callback) def has_singleton(self, tag): return tag in self.entries @@ -160,7 +204,8 @@ def release_singleton(self, tag, obj): return self.entries[tag].release(obj) def unsafe_hard_delete_singleton(self, tag): - return self.entries[tag].unsafe_hard_delete() + self.entries[tag].unsafe_hard_delete() + self._hard_delete_callback() _process_level_singleton_manager = _SingletonManager() @@ -200,9 +245,99 @@ def __call__(self, *args, **kwargs): def __getattr__(self, name): return getattr(self._proxyObject, name) + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + def get_auto_proxy_object(self): return self._proxyObject + def unsafe_hard_delete(self): + try: + self._proxyObject.unsafe_hard_delete() + except (EOFError, ConnectionResetError, BrokenPipeError): + pass + except Exception as e: + logging.warning( + "Exception %s when trying to hard delete shared object proxy", e) + + +def _run_server_process(address_file, tag, constructor, authkey): + """ + Runs in a separate process. + Includes a 'Suicide Pact' monitor: If parent dies, I die. + """ + parent_pid = os.getppid() + + def cleanup_files(): + logging.info("Server process exiting. Deleting files for %s", tag) + try: + if os.path.exists(address_file): + os.remove(address_file) + if os.path.exists(address_file + ".error"): + os.remove(address_file + ".error") + except Exception: + pass + + def handle_unsafe_hard_delete(): + cleanup_files() + os._exit(0) + + def _monitor_parent(): + """Checks if parent is alive every second.""" + while True: + try: + os.kill(parent_pid, 0) + except OSError: + logging.warning( + "Process %s detected Parent %s died. Self-destructing.", + os.getpid(), + parent_pid) + cleanup_files() + os._exit(0) + time.sleep(0.5) + + atexit.register(cleanup_files) + + try: + t = threading.Thread(target=_monitor_parent, daemon=True) + t.start() + + logging.getLogger().setLevel(logging.INFO) + multiprocessing.current_process().authkey = authkey + + serving_manager = _SingletonRegistrar( + address=('localhost', 0), authkey=authkey) + _process_level_singleton_manager.set_hard_delete_callback( + handle_unsafe_hard_delete) + _process_level_singleton_manager.register_singleton( + constructor, + tag, + initialize_eagerly=True, + hard_delete_callback=handle_unsafe_hard_delete) + + server = serving_manager.get_server() + logging.info( + 'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address) + + with open(address_file + '.tmp', 'w') as fout: + fout.write('%s:%d' % server.address) + os.rename(address_file + '.tmp', address_file) + + server.serve_forever() + + except Exception: + tb = traceback.format_exc() + try: + with open(address_file + ".error.tmp", 'w') as fout: + fout.write(tb) + os.rename(address_file + ".error.tmp", address_file + ".error") + except Exception: + print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr) + os._exit(1) + class MultiProcessShared(Generic[T]): """MultiProcessShared is used to share a single object across processes. @@ -252,7 +387,8 @@ def __init__( tag: Any, *, path: str = tempfile.gettempdir(), - always_proxy: Optional[bool] = None): + always_proxy: Optional[bool] = None, + spawn_process: bool = False): self._constructor = constructor self._tag = tag self._path = path @@ -262,6 +398,7 @@ def __init__( self._rpc_address = None self._cross_process_lock = fasteners.InterProcessLock( os.path.join(self._path, self._tag) + '.lock') + self._spawn_process = spawn_process def _get_manager(self): if self._manager is None: @@ -301,6 +438,10 @@ def acquire(self): # Caveat: They must always agree, as they will be ignored if the object # is already constructed. singleton = self._get_manager().acquire_singleton(self._tag) + # Trigger a sweep of zombie processes. + # calling active_children() has the side-effect of joining any finished + # processes, effectively reaping zombies from previous unsafe_hard_deletes. + if self._spawn_process: multiprocessing.active_children() return _AutoProxyWrapper(singleton) def release(self, obj): @@ -315,25 +456,102 @@ def unsafe_hard_delete(self): to this object exist, or (b) you are ok with all existing references to this object throwing strange errors when derefrenced. """ - self._get_manager().unsafe_hard_delete_singleton(self._tag) + try: + self._get_manager().unsafe_hard_delete_singleton(self._tag) + except (EOFError, ConnectionResetError, BrokenPipeError): + pass + except Exception as e: + logging.warning( + "Exception %s when trying to hard delete shared object %s", + e, + self._tag) def _create_server(self, address_file): - # We need to be able to authenticate with both the manager and the process. - self._serving_manager = _SingletonRegistrar( - address=('localhost', 0), authkey=AUTH_KEY) - multiprocessing.current_process().authkey = AUTH_KEY - # Initialize eagerly to avoid acting as the server if there are issues. - # Note, however, that _create_server itself is called lazily. - _process_level_singleton_manager.register_singleton( - self._constructor, self._tag, initialize_eagerly=True) - self._server = self._serving_manager.get_server() - logging.info( - 'Starting proxy server at %s for shared %s', - self._server.address, - self._tag) - with open(address_file + '.tmp', 'w') as fout: - fout.write('%s:%d' % self._server.address) - os.rename(address_file + '.tmp', address_file) - t = threading.Thread(target=self._server.serve_forever, daemon=True) - t.start() - logging.info('Done starting server') + if self._spawn_process: + error_file = address_file + ".error" + + if os.path.exists(error_file): + try: + os.remove(error_file) + except OSError: + pass + + ctx = multiprocessing.get_context('spawn') + p = ctx.Process( + target=_run_server_process, + args=(address_file, self._tag, self._constructor, AUTH_KEY), + daemon=False # Must be False for nested proxies + ) + p.start() + logging.info("Parent: Waiting for %s to write address file...", self._tag) + + def cleanup_process(): + if p.is_alive(): + logging.info( + "Parent: Terminating server process %s for %s", p.pid, self._tag) + p.terminate() + p.join() + try: + if os.path.exists(address_file): + os.remove(address_file) + if os.path.exists(error_file): + os.remove(error_file) + except Exception: + pass + + atexit.register(cleanup_process) + + start_time = time.time() + last_log = start_time + while True: + if os.path.exists(address_file): + break + + if os.path.exists(error_file): + with open(error_file, 'r') as f: + error_msg = f.read() + try: + os.remove(error_file) + except OSError: + pass + + if p.is_alive(): p.terminate() + raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}") + + if not p.is_alive(): + exit_code = p.exitcode + raise RuntimeError( + "Shared Server Process died unexpectedly" + f" with exit code {exit_code}") + + if time.time() - last_log > 300: + logging.warning( + "Still waiting for %s to initialize... %ss elapsed)", + self._tag, + int(time.time() - start_time)) + last_log = time.time() + + time.sleep(0.05) + + logging.info('External process successfully started for %s', self._tag) + else: + # We need to be able to authenticate with both the manager + # and the process. + self._serving_manager = _SingletonRegistrar( + address=('localhost', 0), authkey=AUTH_KEY) + multiprocessing.current_process().authkey = AUTH_KEY + # Initialize eagerly to avoid acting as the server if there are issues. + # Note, however, that _create_server itself is called lazily. + _process_level_singleton_manager.register_singleton( + self._constructor, self._tag, initialize_eagerly=True) + self._server = self._serving_manager.get_server() + logging.info( + 'Starting proxy server at %s for shared %s', + self._server.address, + self._tag) + with open(address_file + '.tmp', 'w') as fout: + fout.write('%s:%d' % self._server.address) + os.rename(address_file + '.tmp', address_file) + t = threading.Thread(target=self._server.serve_forever, daemon=True) + t.start() + logging.info('Done starting server') diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py index 0b7957632368..f3258cf0a968 100644 --- a/sdks/python/apache_beam/utils/multi_process_shared_test.py +++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py @@ -18,6 +18,9 @@ import logging import threading +import tempfile +import os +import multiprocessing import unittest from typing import Any @@ -82,6 +85,14 @@ def __getattribute__(self, __name: str) -> Any: return object.__getattribute__(self, __name) +class SimpleClass: + def make_proxy( + self, tag: str = 'proxy_on_proxy', spawn_process: bool = False): + return multi_process_shared.MultiProcessShared( + Counter, tag=tag, always_proxy=True, + spawn_process=spawn_process).acquire() + + class MultiProcessSharedTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -193,6 +204,34 @@ def test_unsafe_hard_delete(self): self.assertEqual(counter3.increment(), 1) + def test_unsafe_hard_delete_autoproxywrapper(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True) + shared2 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True) + + counter1 = shared1.acquire() + counter2 = shared2.acquire() + self.assertEqual(counter1.increment(), 1) + self.assertEqual(counter2.increment(), 2) + + counter2.unsafe_hard_delete() + + with self.assertRaises(Exception): + counter1.get() + with self.assertRaises(Exception): + counter2.get() + + counter3 = multi_process_shared.MultiProcessShared( + Counter, + tag='test_unsafe_hard_delete_autoproxywrapper', + always_proxy=True).acquire() + self.assertEqual(counter3.increment(), 1) + def test_unsafe_hard_delete_no_op(self): shared1 = multi_process_shared.MultiProcessShared( Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True) @@ -242,6 +281,185 @@ def test_release_always_proxy(self): with self.assertRaisesRegex(Exception, 'released'): counter1.get() + def test_proxy_on_proxy(self): + shared1 = multi_process_shared.MultiProcessShared( + SimpleClass, tag='proxy_on_proxy_main', always_proxy=True) + instance = shared1.acquire() + proxy_instance = instance.make_proxy() + self.assertEqual(proxy_instance.increment(), 1) + + +class MultiProcessSharedSpawnProcessTest(unittest.TestCase): + def setUp(self): + tempdir = tempfile.gettempdir() + for tag in ['basic', + 'proxy_on_proxy', + 'proxy_on_proxy_main', + 'main', + 'to_delete', + 'mix1', + 'mix2' + 'test_process_exit']: + for ext in ['', '.address', '.address.error']: + try: + os.remove(os.path.join(tempdir, tag + ext)) + except OSError: + pass + + def tearDown(self): + for p in multiprocessing.active_children(): + p.terminate() + p.join() + + def test_call(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + self.assertEqual(shared.get(), 0) + self.assertEqual(shared.increment(), 1) + self.assertEqual(shared.increment(10), 11) + self.assertEqual(shared.increment(value=10), 21) + self.assertEqual(shared.get(), 21) + + def test_proxy_on_proxy(self): + shared1 = multi_process_shared.MultiProcessShared( + SimpleClass, tag='main', always_proxy=True) + instance = shared1.acquire() + proxy_instance = instance.make_proxy(spawn_process=True) + self.assertEqual(proxy_instance.increment(), 1) + proxy_instance.unsafe_hard_delete() + + proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True) + self.assertEqual(proxy_instance2.increment(), 1) + + def test_unsafe_hard_delete_autoproxywrapper(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, spawn_process=True) + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, spawn_process=True) + counter3 = multi_process_shared.MultiProcessShared( + Counter, tag='basic', always_proxy=True, spawn_process=True).acquire() + + counter1 = shared1.acquire() + counter2 = shared2.acquire() + self.assertEqual(counter1.increment(), 1) + self.assertEqual(counter2.increment(), 2) + + counter2.unsafe_hard_delete() + + with self.assertRaises(Exception): + counter1.get() + with self.assertRaises(Exception): + counter2.get() + + counter4 = multi_process_shared.MultiProcessShared( + Counter, tag='to_delete', always_proxy=True, + spawn_process=True).acquire() + + self.assertEqual(counter3.increment(), 1) + self.assertEqual(counter4.increment(), 1) + + def test_mix_usage(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire() + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire() + + self.assertEqual(shared1.get(), 0) + self.assertEqual(shared1.increment(), 1) + self.assertEqual(shared2.get(), 0) + self.assertEqual(shared2.increment(), 1) + + def test_process_exits_on_unsafe_hard_delete(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='test_process_exit', always_proxy=True, spawn_process=True) + obj = shared.acquire() + + self.assertEqual(obj.increment(), 1) + + children = multiprocessing.active_children() + server_process = None + for p in children: + if p.pid != os.getpid() and p.is_alive(): + server_process = p + break + + self.assertIsNotNone( + server_process, "Could not find spawned server process") + obj.unsafe_hard_delete() + server_process.join(timeout=5) + + self.assertFalse( + server_process.is_alive(), + f"Server process {server_process.pid} is still alive after hard delete") + self.assertIsNotNone( + server_process.exitcode, "Process has no exit code (did not exit)") + + with self.assertRaises(Exception): + obj.get() + + def test_process_exits_on_unsafe_hard_delete_with_manager(self): + shared = multi_process_shared.MultiProcessShared( + Counter, tag='test_process_exit', always_proxy=True, spawn_process=True) + obj = shared.acquire() + + self.assertEqual(obj.increment(), 1) + + children = multiprocessing.active_children() + server_process = None + for p in children: + if p.pid != os.getpid() and p.is_alive(): + server_process = p + break + + self.assertIsNotNone( + server_process, "Could not find spawned server process") + shared.unsafe_hard_delete() + server_process.join(timeout=5) + + self.assertFalse( + server_process.is_alive(), + f"Server process {server_process.pid} is still alive after hard delete") + self.assertIsNotNone( + server_process.exitcode, "Process has no exit code (did not exit)") + + with self.assertRaises(Exception): + obj.get() + + def test_zombie_reaping_on_acquire(self): + shared1 = multi_process_shared.MultiProcessShared( + Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True) + obj = shared1.acquire() + + children = multiprocessing.active_children() + server_pid = next( + p.pid for p in children if p.is_alive() and p.pid != os.getpid()) + + obj.unsafe_hard_delete() + + try: + os.kill(server_pid, 0) + is_zombie = True + except OSError: + is_zombie = False + self.assertTrue( + is_zombie, + f"Server process {server_pid} was reaped too early before acquire()") + + shared2 = multi_process_shared.MultiProcessShared( + Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True) + _ = shared2.acquire() + + pid_exists = True + try: + os.kill(server_pid, 0) + except OSError: + pid_exists = False + + self.assertFalse( + pid_exists, + f"Old server process {server_pid} was not reaped by acquire() sweep") + shared2.unsafe_hard_delete() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)