From 10ee00afb5ac63bde3d9cc2013fa1e5db393f404 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 21:52:49 +0000 Subject: [PATCH] Add model manager that automatically manage model across processes --- .../apache_beam/ml/inference/model_manager.py | 669 ++++++++++++++++++ .../ml/inference/model_manager_test.py | 548 ++++++++++++++ 2 files changed, 1217 insertions(+) create mode 100644 sdks/python/apache_beam/ml/inference/model_manager.py create mode 100644 sdks/python/apache_beam/ml/inference/model_manager_test.py diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py new file mode 100644 index 000000000000..e56b8ee8d03f --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -0,0 +1,669 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Module for managing ML models in Apache Beam pipelines. + +This module provides classes and functions to efficiently manage multiple +machine learning models within Apache Beam pipelines. It includes functionality +for loading, caching, and updating models using multi-process shared memory, +ensuring that models are reused across different workers to optimize resource +usage and performance. +""" + +import uuid +import time +import threading +import subprocess +import logging +import gc +import numpy as np +from scipy.optimize import nnls +import torch +import heapq +import itertools +from collections import defaultdict, deque, Counter, OrderedDict +from typing import Dict, Any, Tuple, Optional, Callable + +logger = logging.getLogger(__name__) + + +class GPUMonitor: + def __init__( + self, + fallback_memory_mb: float = 16000.0, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0): + self._current_usage = 0.0 + self._peak_usage = 0.0 + self._total_memory = fallback_memory_mb + self._poll_interval = poll_interval + self._peak_window_seconds = peak_window_seconds + self._memory_history = deque() + self._running = False + self._thread = None + self._lock = threading.Lock() + self._gpu_available = self._detect_hardware() + + def _detect_hardware(self): + try: + cmd = [ + "nvidia-smi", + "--query-gpu=memory.total", + "--format=csv,noheader,nounits" + ] + output = subprocess.check_output(cmd, text=True).strip() + self._total_memory = float(output) + return True + except (FileNotFoundError, subprocess.CalledProcessError): + logger.warning( + "nvidia-smi not found or failed. Defaulting total memory to %s MB", + self._total_memory) + return False + except Exception as e: + logger.warning( + "Error parsing nvidia-smi output: %s. 
" + "Defaulting total memory to %s MB", + e, + self._total_memory) + return False + + def start(self): + if self._running or not self._gpu_available: + return + self._running = True + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self): + self._running = False + if self._thread: + self._thread.join() + + def reset_peak(self): + with self._lock: + now = time.time() + self._memory_history.clear() + self._memory_history.append((now, self._current_usage)) + self._peak_usage = self._current_usage + + def get_stats(self) -> Tuple[float, float, float]: + with self._lock: + return self._current_usage, self._peak_usage, self._total_memory + + def refresh(self): + """Forces an immediate poll of the GPU.""" + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + # Recalculate peak immediately + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + + def _get_nvidia_smi_used(self) -> float: + try: + cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" + output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + free_memory = float(output) + return self._total_memory - free_memory + except Exception: + return 0.0 + + def _poll_loop(self): + while self._running: + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + time.sleep(self._poll_interval) + + +class ResourceEstimator: + def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): + self.smoothing_factor = smoothing_factor + self.min_data_points = min_data_points + self.estimates: Dict[str, float] = {} + self.history = defaultdict(lambda: deque(maxlen=20)) + self.known_models = set() + self._lock = threading.Lock() + + def is_unknown(self, model_tag: str) -> bool: + with self._lock: + return model_tag not in self.estimates + + def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float: + with self._lock: + return self.estimates.get(model_tag, default_mb) + + def set_initial_estimate(self, model_tag: str, cost: float): + with self._lock: + self.estimates[model_tag] = cost + self.known_models.add(model_tag) + logger.info("Initial Profile for %s: %s MB", model_tag, cost) + + def add_observation( + self, active_snapshot: Dict[str, int], peak_memory: float): + logger.info( + "Adding Observation: Snapshot=%s, PeakMemory=%.1f MB", + active_snapshot, + peak_memory) + if not active_snapshot: + return + with self._lock: + config_key = tuple(sorted(active_snapshot.items())) + self.history[config_key].append(peak_memory) + for tag in active_snapshot: + self.known_models.add(tag) + self._solve() + + def _solve(self): + """ + Solves Ax=b using raw readings (no pre-averaging) and NNLS. + This creates a 'tall' matrix A where every memory reading is + a separate equation. 
+ """ + unique = sorted(list(self.known_models)) + + # We need to build the matrix first to know if we have enough data points + A, b = [], [] + + for config_key, mem_values in self.history.items(): + if not mem_values: + continue + + # 1. Create the feature row for this configuration ONCE + # (It represents the model counts + bias) + counts = dict(config_key) + feature_row = [counts.get(model, 0) for model in unique] + feature_row.append(1) # Bias column + + # 2. Add a separate row to the matrix for EVERY individual reading + # Instead of averaging, we flatten the history into the matrix + for reading in mem_values: + A.append(feature_row) # The inputs (models) stay the same + b.append(reading) # The output (memory) varies due to noise + + # Convert to numpy for SciPy + A = np.array(A) + b = np.array(b) + + if len( + self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points: + # Not enough data to solve yet + return + + logger.info( + "Solving with %s total observations for %s models.", + len(A), + len(unique)) + + try: + # Solve using Non-Negative Least Squares + # x will be >= 0 + x, _ = nnls(A, b) + + weights = x[:-1] + bias = x[-1] + + for i, model in enumerate(unique): + calculated_cost = weights[i] + + if model in self.estimates: + old = self.estimates[model] + new = (old * (1 - self.smoothing_factor)) + ( + calculated_cost * self.smoothing_factor) + self.estimates[model] = new + else: + self.estimates[model] = calculated_cost + + logger.info( + "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) + logger.info("System Bias: %s MB", bias) + + except Exception as e: + logger.error("Solver failed: %s", e) + + +class TrackedModelProxy: + def __init__(self, obj): + object.__setattr__(self, "_wrapped_obj", obj) + object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) + + def __getattr__(self, name): + return getattr(self._wrapped_obj, name) + + def __setattr__(self, name, value): + setattr(self._wrapped_obj, name, value) + + def __call__(self, *args, **kwargs): + return self._wrapped_obj(*args, **kwargs) + + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + + def __str__(self): + return str(self._wrapped_obj) + + def __repr__(self): + return repr(self._wrapped_obj) + + def __dir__(self): + return dir(self._wrapped_obj) + + def unsafe_hard_delete(self): + if hasattr(self._wrapped_obj, "unsafe_hard_delete"): + self._wrapped_obj.unsafe_hard_delete() + + +class ModelManager: + _lock = threading.Lock() + + def __init__( + self, + monitor: Optional['GPUMonitor'] = None, + slack_percentage: float = 0.10, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0, + min_data_points: int = 5, + smoothing_factor: float = 0.2, + eviction_cooldown_seconds: float = 10.0, + min_model_copies: int = 1): + + self._estimator = ResourceEstimator( + min_data_points=min_data_points, smoothing_factor=smoothing_factor) + self._monitor = monitor if monitor else GPUMonitor( + poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) + self._slack_percentage = slack_percentage + + self._eviction_cooldown = eviction_cooldown_seconds + self._min_model_copies = min_model_copies + + # Resource State + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + self._wait_queue = [] + 
self._ticket_counter = itertools.count() + self._cv = threading.Condition() + self._load_lock = threading.Lock() + + self._monitor.start() + + def all_models(self, tag) -> list[Any]: + return self._models[tag] + + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: + current_priority = 0 if self._estimator.is_unknown(tag) else 1 + ticket_num = next(self._ticket_counter) + my_id = object() + + with self._cv: + # FAST PATH + if self._pending_isolation_count == 0 and not self._isolation_mode: + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + # SLOW PATH + logger.info( + "Acquire Queued: tag=%s, priority=%d " + "total models count=%s ticket num=%s", + tag, + current_priority, + len(self._models[tag]), + ticket_num) + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + + should_spawn = False + est_cost = 0.0 + is_unknown = False + + try: + while True: + if not self._wait_queue or self._wait_queue[0][2] is not my_id: + logger.info( + "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) + self._cv.wait() + continue + + real_is_unknown = self._estimator.is_unknown(tag) + real_priority = 0 if real_is_unknown else 1 + + if current_priority != real_priority: + heapq.heappop(self._wait_queue) + current_priority = real_priority + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + self._cv.notify_all() + continue + + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + is_unknown = real_is_unknown + + # Path A: Isolation + if is_unknown: + if self._total_active_jobs > 0: + logger.info( + "Waiting to enter isolation: tag=%s ticket num=%s", + tag, + ticket_num) + self._cv.wait() + continue + + logger.info("Unknown model %s detected. 
Flushing GPU.", tag) + self._delete_all_models() + + self._isolation_mode = True + self._total_active_jobs += 1 + self._isolation_baseline, _, _ = self._monitor.get_stats() + self._monitor.reset_peak() + should_spawn = True + break + + # Path B: Concurrent + else: + if self._pending_isolation_count > 0 or self._isolation_mode: + logger.info( + "Waiting due to isolation in progress: tag=%s ticket num%s", + tag, + ticket_num) + self._cv.wait() + continue + + curr, _, total = self._monitor.get_stats() + est_cost = self._estimator.get_estimate(tag) + limit = total * (1 - self._slack_percentage) + + # Use current usage for capacity check (ignore old spikes) + if (curr + self._pending_reservations + est_cost) <= limit: + self._pending_reservations += est_cost + self._total_active_jobs += 1 + self._active_counts[tag] += 1 + should_spawn = True + break + + # Evict to make space (passing tag to check demand/existence) + if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + continue + + idle_count = 0 + other_idle_count = 0 + for item in self._idle_lru.items(): + if item[1][0] == tag: + idle_count += 1 + else: + other_idle_count += 1 + total_model_count = 0 + for _, instances in self._models.items(): + total_model_count += len(instances) + curr, _, _ = self._monitor.get_stats() + logger.info( + "Waiting for resources to free up: " + "tag=%s ticket num%s model count=%s " + "idle count=%s resource usage=%.1f MB " + "total models count=%s other idle=%s", + tag, + ticket_num, + len(self._models[tag]), + idle_count, + curr, + total_model_count, + other_idle_count) + self._cv.wait(timeout=10.0) + + finally: + if self._wait_queue and self._wait_queue[0][2] is my_id: + heapq.heappop(self._wait_queue) + else: + for i, item in enumerate(self._wait_queue): + if item[2] is my_id: + self._wait_queue.pop(i) + heapq.heapify(self._wait_queue) + self._cv.notify_all() + + if should_spawn: + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) + + def release_model(self, tag: str, instance: Any): + with self._cv: + try: + self._total_active_jobs -= 1 + if self._active_counts[tag] > 0: + self._active_counts[tag] -= 1 + + self._idle_lru[id(instance)] = (tag, instance, time.time()) + + _, peak_during_job, _ = self._monitor.get_stats() + + if self._isolation_mode and self._active_counts[tag] == 0: + cost = max(0, peak_during_job - self._isolation_baseline) + self._estimator.set_initial_estimate(tag, cost) + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + snapshot = { + t: len(instances) + for t, instances in self._models.items() if len(instances) > 0 + } + if snapshot: + self._estimator.add_observation(snapshot, peak_during_job) + + finally: + self._cv.notify_all() + + def _try_grab_from_lru(self, tag: str) -> Any: + target_key = None + target_instance = None + + for key, (t, instance, _) in reversed(self._idle_lru.items()): + if t == tag: + target_key = key + target_instance = instance + break + + if target_instance: + del self._idle_lru[target_key] + self._active_counts[tag] += 1 + self._total_active_jobs += 1 + return target_instance + + logger.info("No idle model found for tag: %s", tag) + return None + + def _evict_to_make_space( + self, limit: float, est_cost: float, requesting_tag: str) -> bool: + """ + Evicts models based on Demand Magnitude + Tiers. + Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction + of the lowest-demand candidate to avoid starvation. 
+ """ + evicted_something = False + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + if projected_usage <= limit: + return False + + now = time.time() + + demand_map = Counter() + for item in self._wait_queue: + if len(item) >= 4: + demand_map[item[3]] += 1 + + my_demand = demand_map[requesting_tag] + am_i_starving = len(self._models[requesting_tag]) == 0 + + candidates = [] + for key, (tag, instance, release_time) in self._idle_lru.items(): + candidate_demand = demand_map[tag] + + if not am_i_starving: + if candidate_demand >= my_demand: + continue + + age = now - release_time + is_cold = age >= self._eviction_cooldown + + total_copies = len(self._models[tag]) + is_surplus = total_copies > self._min_model_copies + + if is_cold and is_surplus: tier = 0 + elif not is_cold and is_surplus: tier = 1 + elif is_cold and not is_surplus: tier = 2 + else: tier = 3 + + score = (candidate_demand * 10) + tier + + candidates.append((score, release_time, key, tag, instance)) + + candidates.sort(key=lambda x: (x[0], x[1])) + + for score, _, key, tag, instance in candidates: + if projected_usage <= limit: + break + + if key not in self._idle_lru: continue + + self._perform_eviction(key, tag, instance, score) + evicted_something = True + + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + return evicted_something + + def _perform_eviction(self, key, tag, instance, score): + logger.info("Evicting Model: %s (Score %d)", tag, score) + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage Before Eviction: %.1f MB", curr) + + if key in self._idle_lru: + del self._idle_lru[key] + + target_id = instance._beam_tracking_id + for i, inst in enumerate(self._models[tag]): + if inst._beam_tracking_id == target_id: + del self._models[tag][i] + break + + instance.unsafe_hard_delete() + del instance + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage After Eviction: %.1f MB", curr) + + def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): + try: + with self._load_lock: + logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) + isolation_baseline_snap, _, _ = self._monitor.get_stats() + instance = TrackedModelProxy(loader_func()) + _, peak_during_load, _ = self._monitor.get_stats() + + with self._cv: + snapshot = {tag: 1} + self._estimator.add_observation( + snapshot, peak_during_load - isolation_baseline_snap) + + if not is_unknown: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._models[tag].append(instance) + return instance + + except Exception as e: + logger.error("Load Failed: %s. 
Error: %s", tag, e) + with self._cv: + self._total_active_jobs -= 1 + if is_unknown: + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._active_counts[tag] -= 1 + self._cv.notify_all() + raise e + + def _delete_all_models(self): + self._idle_lru.clear() + for _, instances in self._models.items(): + for instance in instances: + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + del instance + self._models.clear() + self._active_counts.clear() + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + + def _force_reset(self): + logger.warning("Force Reset Triggered") + self._delete_all_models() + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._wait_queue = [] + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + def shutdown(self): + self._delete_all_models() + gc.collect() + torch.cuda.empty_cache() + self._monitor.stop() + + def __del__(self): + self.shutdown() + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py new file mode 100644 index 000000000000..7412ea6a6c64 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -0,0 +1,548 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import time +import threading +import random +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import patch + +try: + from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator +except ImportError as e: + raise unittest.SkipTest("Model Manager dependencies are not installed") + + +class MockGPUMonitor: + """ + Simulates GPU hardware with cumulative memory tracking. + Allows simulating specific allocation spikes and baseline usage. 
+ """ + def __init__(self, total_memory=12000.0, peak_window: int = 5): + self._current = 0.0 + self._peak = 0.0 + self._total = total_memory + self._lock = threading.Lock() + self.running = False + self.history = [] + self.peak_window = peak_window + + def start(self): + self.running = True + + def stop(self): + self.running = False + + def get_stats(self): + with self._lock: + return self._current, self._peak, self._total + + def reset_peak(self): + with self._lock: + self._peak = self._current + self.history = [self._current] + + def set_usage(self, current_mb): + """Sets absolute usage (legacy helper).""" + with self._lock: + self._current = current_mb + self._peak = max(self._peak, current_mb) + + def allocate(self, amount_mb): + """Simulates memory allocation (e.g., tensors loaded to VRAM).""" + with self._lock: + self._current += amount_mb + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def free(self, amount_mb): + """Simulates memory freeing (not used often if pooling is active).""" + with self._lock: + self._current = max(0.0, self._current - amount_mb) + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def refresh(self): + """Simulates a refresh of the monitor stats (no-op for mock).""" + pass + + +class MockModel: + def __init__(self, name, size, monitor): + self.name = name + self.size = size + self.monitor = monitor + self.deleted = False + self.monitor.allocate(size) + + def unsafe_hard_delete(self): + if not self.deleted: + self.monitor.free(self.size) + self.deleted = True + + +class TestModelManager(unittest.TestCase): + def setUp(self): + """Force reset the Singleton ModelManager before every test.""" + ModelManager._instance = None + self.mock_monitor = MockGPUMonitor() + self.manager = ModelManager(monitor=self.mock_monitor) + + def tearDown(self): + self.manager.shutdown() + + def test_model_manager_capacity_check(self): + """ + Test that the manager blocks when spawning models exceeds the limit, + and unblocks when resources become available (via reuse). + """ + model_name = "known_model" + model_cost = 3000.0 + self.manager._estimator.set_initial_estimate(model_name, model_cost) + acquired_refs = [] + + def loader(): + self.mock_monitor.allocate(model_cost) + return model_name + + # 1. Saturate GPU with 3 models (9000 MB usage) + for _ in range(3): + inst = self.manager.acquire_model(model_name, loader) + acquired_refs.append(inst) + + # 2. Spawn one more (Should Block because 9000 + 3000 > Limit) + def run_inference(): + return self.manager.acquire_model(model_name, loader) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_inference) + try: + future.result(timeout=0.5) + self.fail("Should have blocked due to capacity") + except TimeoutError: + pass + + # 3. 
Release resources to unblock + item_to_release = acquired_refs.pop() + self.manager.release_model(model_name, item_to_release) + + result = future.result(timeout=2.0) + self.assertIsNotNone(result) + self.assertEqual(result, item_to_release) + + def test_model_manager_unknown_model_runs_isolated(self): + """Test that a model with no history runs in isolation.""" + model_name = "unknown_model_v1" + self.assertTrue(self.manager._estimator.is_unknown(model_name)) + + def dummy_loader(): + time.sleep(0.05) + return "model_instance" + + instance = self.manager.acquire_model(model_name, dummy_loader) + + self.assertTrue(self.manager._isolation_mode) + self.assertEqual(self.manager._total_active_jobs, 1) + + self.manager.release_model(model_name, instance) + self.assertFalse(self.manager._isolation_mode) + self.assertFalse(self.manager._estimator.is_unknown(model_name)) + + def test_model_manager_concurrent_execution(self): + """Test that multiple small known models can run together.""" + model_a = "small_model_a" + model_b = "small_model_b" + + self.manager._estimator.set_initial_estimate(model_a, 1000.0) + self.manager._estimator.set_initial_estimate(model_b, 1000.0) + self.mock_monitor.set_usage(1000.0) + + inst_a = self.manager.acquire_model(model_a, lambda: "A") + inst_b = self.manager.acquire_model(model_b, lambda: "B") + + self.assertEqual(self.manager._total_active_jobs, 2) + + self.manager.release_model(model_a, inst_a) + self.manager.release_model(model_b, inst_b) + self.assertEqual(self.manager._total_active_jobs, 0) + + def test_model_manager_concurrent_mixed_workload_convergence(self): + """ + Simulates a production environment with multiple model types running + concurrently. Verifies that the estimator converges. + """ + TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} + + def run_job(model_name): + cost = TRUE_COSTS[model_name] + + def loader(): + model = MockModel(model_name, cost, self.mock_monitor) + return model + + instance = self.manager.acquire_model(model_name, loader) + time.sleep(random.uniform(0.01, 0.05)) + self.manager.release_model(model_name, instance) + + # Create a workload stream + workload = ["model_small"] * 15 + ["model_medium"] * 15 + random.shuffle(workload) + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_job, name) for name in workload] + for f in futures: + f.result() + + est_small = self.manager._estimator.get_estimate("model_small") + est_med = self.manager._estimator.get_estimate("model_medium") + + self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) + self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) + + def test_model_manager_oom_recovery(self): + """Test that the manager recovers state if a loader crashes.""" + model_name = "crasher_model" + self.manager._estimator.set_initial_estimate(model_name, 1000.0) + + def crashing_loader(): + raise RuntimeError("CUDA OOM or similar") + + with self.assertRaises(RuntimeError): + self.manager.acquire_model(model_name, crashing_loader) + + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._cv._is_owned()) + + def test_model_managaer_force_reset_on_exception(self): + """Test that force_reset clears all models from the manager.""" + model_name = "test_model" + + def dummy_loader(): + self.mock_monitor.allocate(1000.0) + raise RuntimeError("Simulated loader exception") + + try: + instance = self.manager.acquire_model( + 
model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + instance = self.manager.acquire_model(model_name, dummy_loader) + except RuntimeError: + self.manager._force_reset() + self.assertTrue(len(self.manager._models[model_name]) == 0) + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._isolation_mode) + pass + + instance = self.manager.acquire_model(model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + + def test_single_model_convergence_with_fluctuations(self): + """ + Tests that the estimator converges to the true usage with fluctuations. + """ + model_name = "fluctuating_model" + model_cost = 3000.0 + load_cost = 2000.0 + + def loader(): + self.mock_monitor.allocate(load_cost) + return model_name + + model = self.manager.acquire_model(model_name, loader) + self.manager.release_model(model_name, model) + initial_est = self.manager._estimator.get_estimate(model_name) + self.assertEqual(initial_est, load_cost) + + def run_inference(): + model = self.manager.acquire_model(model_name, loader) + noise = model_cost - load_cost + random.uniform(-300.0, 300.0) + self.mock_monitor.allocate(noise) + time.sleep(0.1) + self.mock_monitor.free(noise) + self.manager.release_model(model_name, model) + return + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_inference) for _ in range(100)] + + for f in futures: + f.result() + + est_cost = self.manager._estimator.get_estimate(model_name) + self.assertAlmostEqual(est_cost, model_cost, delta=100.0) + + +class TestModelManagerEviction(unittest.TestCase): + def setUp(self): + self.mock_monitor = MockGPUMonitor(total_memory=12000.0) + ModelManager._instance = None + self.manager = ModelManager( + monitor=self.mock_monitor, + slack_percentage=0.0, + min_data_points=1, + eviction_cooldown_seconds=10.0, + min_model_copies=1) + + def tearDown(self): + self.manager.shutdown() + + def create_loader(self, name, size): + return lambda: MockModel(name, size, self.mock_monitor) + + def test_basic_lru_eviction(self): + self.manager._estimator.set_initial_estimate("A", 4000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 5000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a) + + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_a = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0) + + key_b = list(self.manager._idle_lru.keys())[1] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + model_a_again = self.manager.acquire_model( + "A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a_again) + + self.manager.acquire_model("C", self.create_loader("C", 5000)) + + self.assertEqual(len(self.manager.all_models("B")), 0) + self.assertEqual(len(self.manager.all_models("A")), 1) + + def test_chained_eviction(self): + self.manager._estimator.set_initial_estimate("big_guy", 8000) + models = [] + for i in range(4): + name = f"small_{i}" + m = self.manager.acquire_model(name, self.create_loader(name, 3000)) + self.manager.release_model(name, m) + models.append(m) + + self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000)) + + self.assertTrue(models[0].deleted) + 
self.assertTrue(models[1].deleted) + self.assertTrue(models[2].deleted) + self.assertFalse(models[3].deleted) + + def test_active_models_are_protected(self): + self.manager._estimator.set_initial_estimate("A", 6000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 4000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 6000)) + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_b = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + def acquire_c(): + return self.manager.acquire_model("C", self.create_loader("C", 4000)) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(acquire_c) + model_c = future.result(timeout=2.0) + + self.assertTrue(model_b.deleted) + self.assertFalse(model_a.deleted) + + self.manager.release_model("A", model_a) + self.manager.release_model("C", model_c) + + def test_unknown_model_clears_memory(self): + self.manager._estimator.set_initial_estimate("A", 2000) + model_a = self.manager.acquire_model("A", self.create_loader("A", 2000)) + self.manager.release_model("A", model_a) + self.assertFalse(model_a.deleted) + + self.assertTrue(self.manager._estimator.is_unknown("X")) + model_x = self.manager.acquire_model("X", self.create_loader("X", 10000)) + + self.assertTrue(model_a.deleted, "Model A should be deleted for isolation") + self.assertEqual(len(self.manager.all_models("A")), 0) + self.assertTrue(self.manager._isolation_mode) + self.manager.release_model("X", model_x) + + def test_concurrent_eviction_pressure(self): + def worker(idx): + name = f"model_{idx % 5}" + try: + m = self.manager.acquire_model(name, self.create_loader(name, 4000)) + time.sleep(0.001) + self.manager.release_model(name, m) + except Exception: + pass + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(worker, i) for i in range(50)] + for f in futures: + f.result() + + curr, _, _ = self.mock_monitor.get_stats() + expected_usage = 0 + for _, instances in self.manager._models.items(): + expected_usage += len(instances) * 4000 + + self.assertAlmostEqual(curr, expected_usage) + + def test_starvation_prevention_overrides_demand(self): + self.manager._estimator.set_initial_estimate("A", 12000) + m_a = self.manager.acquire_model("A", self.create_loader("A", 12000)) + self.manager.release_model("A", m_a) + + def cycle_a(): + try: + m = self.manager.acquire_model("A", self.create_loader("A", 12000)) + time.sleep(0.3) + self.manager.release_model("A", m) + except Exception: + pass + + executor = ThreadPoolExecutor(max_workers=5) + for _ in range(5): + executor.submit(cycle_a) + + def acquire_b(): + return self.manager.acquire_model("B", self.create_loader("B", 4000)) + + b_future = executor.submit(acquire_b) + model_b = b_future.result() + + self.assertTrue(m_a.deleted) + self.manager.release_model("B", model_b) + executor.shutdown(wait=True) + + +class TestGPUMonitor(unittest.TestCase): + def setUp(self): + self.subprocess_patcher = patch('subprocess.check_output') + self.mock_subprocess = self.subprocess_patcher.start() + + def tearDown(self): + self.subprocess_patcher.stop() + + def test_init_hardware_detected(self): + """Test that init correctly reads total memory when nvidia-smi exists.""" + self.mock_subprocess.return_value = "24576" + monitor = GPUMonitor() + self.assertTrue(monitor._gpu_available) + 
self.assertEqual(monitor._total_memory, 24576.0) + + def test_init_hardware_missing(self): + """Test fallback behavior when nvidia-smi is missing.""" + self.mock_subprocess.side_effect = FileNotFoundError() + monitor = GPUMonitor(fallback_memory_mb=12000.0) + self.assertFalse(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 12000.0) + + @patch('time.sleep') + def test_polling_updates_stats(self, mock_sleep): + """Test that the polling loop updates current and peak usage.""" + def subprocess_side_effect(*args, **kwargs): + if isinstance(args[0], list) and "memory.total" in args[0][1]: + return "16000" + + if isinstance(args[0], str) and "memory.free" in args[0]: + return b"12000" + + raise Exception("Unexpected command") + + self.mock_subprocess.side_effect = subprocess_side_effect + self.mock_subprocess.return_value = None + + monitor = GPUMonitor() + monitor.start() + time.sleep(0.1) + curr, peak, total = monitor.get_stats() + monitor.stop() + + self.assertEqual(curr, 4000.0) + self.assertEqual(peak, 4000.0) + self.assertEqual(total, 16000.0) + + def test_reset_peak(self): + """Test that resetting peak usage works.""" + monitor = GPUMonitor() + monitor._gpu_available = True + + with monitor._lock: + monitor._current_usage = 2000.0 + monitor._peak_usage = 8000.0 + monitor._memory_history.append((time.time(), 8000.0)) + monitor._memory_history.append((time.time(), 2000.0)) + + monitor.reset_peak() + + _, peak, _ = monitor.get_stats() + self.assertEqual(peak, 2000.0) + + +class TestResourceEstimatorSolver(unittest.TestCase): + def setUp(self): + self.estimator = ResourceEstimator() + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_min_data_points(self, mock_nnls): + mock_nnls.return_value = ([100.0, 50.0], 0.0) + + self.estimator.add_observation({'model_A': 1}, 500) + self.estimator.add_observation({'model_B': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_B': 1}, 500) + self.assertTrue(mock_nnls.called) + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_unique_model_constraint(self, mock_nnls): + mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0) + + for _ in range(5): + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 800) + + for _ in range(5): + self.estimator.add_observation({'model_C': 1}, 400) + + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 300) + self.estimator.add_observation({'model_B': 1}, 300) + + self.assertTrue(mock_nnls.called) + + +if __name__ == "__main__": + unittest.main()
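
A minimal usage sketch of the ModelManager API introduced by this patch, assuming the module path sdks/python/apache_beam/ml/inference/model_manager.py added above and a machine without nvidia-smi, so GPUMonitor falls back to its configured total memory. FakeModel, load_model and the "demo" tag are illustrative stand-ins and are not part of the patch.

# Usage sketch only; FakeModel, load_model and the "demo" tag are invented.
from apache_beam.ml.inference.model_manager import ModelManager


class FakeModel:
  """Stand-in for a real model handle, used only for this sketch."""
  def predict(self, batch):
    return batch

  def unsafe_hard_delete(self):
    pass  # A real GPU-backed model would free its device memory here.


def load_model():
  # In a real pipeline this would load weights onto the accelerator.
  return FakeModel()


manager = ModelManager(slack_percentage=0.10, min_model_copies=1)

# acquire_model blocks until the estimator believes the model fits; an
# unknown tag is run in isolation first so its memory cost can be profiled.
instance = manager.acquire_model("demo", load_model)
print(instance.predict([1, 2, 3]))

# Releasing returns the instance to the idle LRU for reuse and feeds a
# peak-memory observation back to the estimator.
manager.release_model("demo", instance)
manager.shutdown()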
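
The ResourceEstimator._solve docstring above describes building a "tall" Ax = b system in which every raw peak-memory reading is its own equation and solving it with scipy.optimize.nnls. The standalone sketch below reproduces that formulation on invented numbers (1500 MB and 3000 MB per-model costs plus a 500 MB bias) purely to show how the per-model count columns and the trailing bias column recover the costs; none of these figures come from the patch.

# Illustration of the estimator's NNLS formulation: one row of A per raw
# memory reading, one column per model (its active count) plus a bias column.
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(0)
true_small, true_medium, true_bias = 1500.0, 3000.0, 500.0  # invented values

rows, readings = [], []
configs = [
    {"small": 1, "medium": 0},
    {"small": 0, "medium": 1},
    {"small": 2, "medium": 1},
]
for counts in configs:
  feature_row = [float(counts["small"]), float(counts["medium"]), 1.0]
  for _ in range(10):  # every noisy reading becomes a separate equation
    peak = (counts["small"] * true_small + counts["medium"] * true_medium +
            true_bias + rng.normal(0, 50))
    rows.append(feature_row)
    readings.append(peak)

# Non-negative least squares keeps the per-model costs and bias >= 0.
x, _ = nnls(np.array(rows), np.array(readings))
print("small ~", round(x[0]), "medium ~", round(x[1]), "bias ~", round(x[2]))
# Prints estimates close to 1500, 3000 and 500 respectively.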
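
The _evict_to_make_space docstring describes ordering idle candidates by queued demand plus a cold/surplus tier, with a starvation override when the requester owns no copies. The toy sketch below only reproduces the scoring rule from the patch (score = demand * 10 + tier, lower scores evicted first) on made-up candidates; it omits the starvation override and the capacity loop, and the tags and numbers are invented.

# Toy reproduction of the eviction ordering used by _evict_to_make_space.
EVICTION_COOLDOWN_S = 10.0
MIN_MODEL_COPIES = 1


def tier(idle_age_s: float, total_copies: int) -> int:
  is_cold = idle_age_s >= EVICTION_COOLDOWN_S
  is_surplus = total_copies > MIN_MODEL_COPIES
  if is_cold and is_surplus:
    return 0  # best candidate: cold and redundant
  if not is_cold and is_surplus:
    return 1
  if is_cold and not is_surplus:
    return 2
  return 3    # worst candidate: warm and the last copy


# (tag, queued_demand, idle_age_seconds, total_copies) -- invented examples
candidates = [
    ("bert", 0, 30.0, 2),    # no waiters, cold, surplus     -> score 0
    ("resnet", 2, 30.0, 1),  # two waiters, cold, last copy  -> score 22
    ("t5", 1, 2.0, 1),       # one waiter, warm, last copy   -> score 13
]
scored = sorted(
    (demand * 10 + tier(age, copies), tag)
    for tag, demand, age, copies in candidates)
print(scored)  # [(0, 'bert'), (13, 't5'), (22, 'resnet')] -> evict 'bert' first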