Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,30 @@ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
return self._base.get_postprocess_fns() + [self._postprocess_fn]


class OOMProtectedFn:
  """Wraps a callable and performs best-effort GPU memory cleanup when the
  wrapped call fails with a CUDA out-of-memory error.

  The original exception is always re-raised; this wrapper only tries to
  release cached CUDA memory so that subsequent attempts (e.g. a model
  reload) are more likely to succeed.
  """
  def __init__(self, func):
    # The callable to protect (e.g. a model loader_func or run_inference).
    self.func = func

  def __call__(self, *args, **kwargs):
    try:
      return self.func(*args, **kwargs)
    except Exception as e:
      # Check string to avoid hard import dependency
      if 'out of memory' in str(e) and 'CUDA' in str(e):
        logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
        try:
          import gc

          import torch
          gc.collect()
          torch.cuda.empty_cache()
        except ImportError:
          # torch is not installed; nothing to clean up.
          pass
        except Exception as cleanup_error:
          # Cleanup is best-effort; never mask the original OOM error.
          logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
      # Bare raise re-raises the active exception with its original
      # traceback intact, without recording this line as a new raise site.
      raise


class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
Iterable[ExampleT]]],
beam.PCollection[PredictionT]]):
Expand Down Expand Up @@ -1831,7 +1855,9 @@ def __call__(self):
unique_tag = self.model_tag + '_' + uuid.uuid4().hex
# Ensure that each model loaded in a different process for parallelism
multi_process_shared.MultiProcessShared(
self.loader_func, tag=unique_tag, always_proxy=True,
OOMProtectedFn(self.loader_func),
tag=unique_tag,
always_proxy=True,
spawn_process=True).acquire()
# Only return the tag to avoid pickling issues with the model itself.
return unique_tag
Expand Down Expand Up @@ -2021,10 +2047,13 @@ def _run_inference(self, batch, inference_args):
unique_tag = model
model = multi_process_shared.MultiProcessShared(
lambda: None, tag=model, always_proxy=True).acquire()
result_generator = self._model_handler.run_inference(
batch, model, inference_args)
if self.use_model_manager:
self._model.release_model(self._model_tag, unique_tag)
try:
result_generator = (OOMProtectedFn(self._model_handler.run_inference))(
batch, model, inference_args)
finally:
# Always release the model so that it can be reloaded.
if self.use_model_manager:
self._model.release_model(self._model_tag, unique_tag)
except BaseException as e:
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
Expand Down
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import multiprocessing
import os
import pickle
import random
import sys
import tempfile
import time
Expand Down Expand Up @@ -2338,6 +2339,31 @@ def test_run_inference_impl_with_model_manager_args(self):
})
assert_that(actual, equal_to(expected), label='assert:inferences')

@unittest.skipIf(
not try_import_model_manager(), 'Model Manager not available')
def test_run_inference_impl_with_model_manager_oom(self):
class OOMFakeModelHandler(SimpleFakeModelHandler):
def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
if random.random() < 0.8:
raise MemoryError("Simulated OOM")
for example in batch:
yield model.predict(example)

def batch_elements_kwargs(self):
return {'min_batch_size': 1, 'max_batch_size': 1}

with self.assertRaises(Exception):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
OOMFakeModelHandler(), use_model_manager=True)
assert_that(actual, equal_to([2, 6, 4, 11]), label='assert:inferences')
Comment on lines 2345 to 2365
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new test has a few issues that could be improved:

  1. Non-determinism: The use of random.random() makes this test non-deterministic, which can lead to flaky builds. It's better to make tests deterministic.

  2. Confusing structure: The assert_that call is inside the with self.assertRaises(Exception): block. This means the assertion is only checked if the test is about to fail anyway because no exception was raised. It's clearer to have separate tests for success and failure cases.

  3. Incomplete test coverage: The OOMProtectedFn specifically looks for 'out of memory' and 'CUDA' in the exception string to trigger the memory cleanup logic. The MemoryError("Simulated OOM") raised here does not contain 'out of memory', so the cleanup path is not actually being tested.

I'm suggesting a change to make this test deterministic and to correctly test the OOM cleanup path by raising a more specific error. This will make the test more reliable and ensure the new functionality is properly verified.

    class OOMFakeModelHandler(SimpleFakeModelHandler):
      def run_inference(
          self,
          batch: Sequence[int],
          model: FakeModel,
          inference_args=None) -> Iterable[int]:
        # This will always raise to test the OOM path.
        raise MemoryError("CUDA out of memory. Simulated OOM.")

    with self.assertRaises(MemoryError):
      with TestPipeline() as pipeline:
        examples = [1, 5, 3, 10]
        pcoll = pipeline | 'start' >> beam.Create(examples)
        # The pipeline will fail, so we don't need to check the output.
        _ = pcoll | base.RunInference(
            OOMFakeModelHandler(), use_model_manager=True)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the suggestion is good, but there is a 20% chance this test succeeds. Could we drop the batch size to 1? That would help since we'd get 4 run_inference calls instead of just one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, good catch! I didn't notice the batch size, so I was assuming the chance of a spurious pass was 0.2^4 — will update!



if __name__ == '__main__':
unittest.main()
63 changes: 42 additions & 21 deletions sdks/python/apache_beam/ml/inference/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,23 @@ class ResourceEstimator:
individual models based on aggregate system memory readings and the
configuration of active models at that time.
"""
def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5):
def __init__(
self,
smoothing_factor: float = 0.2,
min_data_points: int = 5,
verbose_logging: bool = False):
self.smoothing_factor = smoothing_factor
self.min_data_points = min_data_points
self.verbose_logging = verbose_logging
self.estimates: Dict[str, float] = {}
self.history = defaultdict(lambda: deque(maxlen=20))
self.known_models = set()
self._lock = threading.Lock()

  def logging_info(self, message: str, *args):
    # Emit an info-level log only when verbose logging was requested at
    # construction time; args are forwarded for lazy %-style formatting.
    if self.verbose_logging:
      logger.info(message, *args)

  def is_unknown(self, model_tag: str) -> bool:
    # A model is "unknown" until it has a memory-cost estimate recorded.
    # Taken under the lock because estimates is mutated by other threads.
    with self._lock:
      return model_tag not in self.estimates
Expand All @@ -196,7 +205,7 @@ def set_initial_estimate(self, model_tag: str, cost: float):
with self._lock:
self.estimates[model_tag] = cost
self.known_models.add(model_tag)
logger.info("Initial Profile for %s: %s MB", model_tag, cost)
self.logging_info("Initial Profile for %s: %s MB", model_tag, cost)

def add_observation(
self, active_snapshot: Dict[str, int], peak_memory: float):
Expand All @@ -207,7 +216,7 @@ def add_observation(
else:
model_list = "\t- None"

logger.info(
self.logging_info(
"Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s",
peak_memory,
model_list)
Expand Down Expand Up @@ -256,7 +265,7 @@ def _solve(self):
# Not enough data to solve yet
return

logger.info(
self.logging_info(
"Solving with %s total observations for %s models.",
len(A),
len(unique))
Expand All @@ -280,9 +289,9 @@ def _solve(self):
else:
self.estimates[model] = calculated_cost

logger.info(
self.logging_info(
"Updated Estimate for %s: %.1f MB", model, self.estimates[model])
logger.info("System Bias: %s MB", bias)
self.logging_info("System Bias: %s MB", bias)

except Exception as e:
logger.error("Solver failed: %s", e)
Expand Down Expand Up @@ -321,10 +330,13 @@ def __init__(
eviction_cooldown_seconds: float = 10.0,
min_model_copies: int = 1,
wait_timeout_seconds: float = 300.0,
lock_timeout_seconds: float = 60.0):
lock_timeout_seconds: float = 60.0,
verbose_logging: bool = False):

self._estimator = ResourceEstimator(
min_data_points=min_data_points, smoothing_factor=smoothing_factor)
min_data_points=min_data_points,
smoothing_factor=smoothing_factor,
verbose_logging=verbose_logging)
self._monitor = monitor if monitor else GPUMonitor(
poll_interval=poll_interval, peak_window_seconds=peak_window_seconds)
self._slack_percentage = slack_percentage
Expand All @@ -333,6 +345,7 @@ def __init__(
self._min_model_copies = min_model_copies
self._wait_timeout_seconds = wait_timeout_seconds
self._lock_timeout_seconds = lock_timeout_seconds
self._verbose_logging = verbose_logging

# Resource State
self._models = defaultdict(list)
Expand Down Expand Up @@ -361,20 +374,24 @@ def __init__(

self._monitor.start()

  def logging_info(self, message: str, *args):
    # Emit an info-level log only when verbose logging is enabled;
    # args are forwarded for lazy %-style formatting.
    if self._verbose_logging:
      logger.info(message, *args)

  def all_models(self, tag) -> list[Any]:
    # Returns every loaded instance for this tag; _models is a
    # defaultdict(list), so an unseen tag yields an empty list.
    return self._models[tag]

# Should hold _cv lock when calling
def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
if self._total_active_jobs > 0:
logger.info(
self.logging_info(
"Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num)
self._cv.wait(timeout=self._lock_timeout_seconds)
# return False since we have waited and need to re-evaluate
# in caller to make sure our priority is still valid.
return False

logger.info("Unknown model %s detected. Flushing GPU.", tag)
self.logging_info("Unknown model %s detected. Flushing GPU.", tag)
self._delete_all_models()

self._isolation_mode = True
Expand Down Expand Up @@ -412,7 +429,7 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
for _, instances in self._models.items():
total_model_count += len(instances)
curr, _, _ = self._monitor.get_stats()
logger.info(
self.logging_info(
"Waiting for resources to free up: "
"tag=%s ticket num%s model count=%s "
"idle count=%s resource usage=%.1f MB "
Expand Down Expand Up @@ -462,7 +479,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
# SLOW PATH: Enqueue and wait for turn to acquire model,
# with unknown models having priority and order enforced
# by ticket number as FIFO.
logger.info(
self.logging_info(
"Acquire Queued: tag=%s, priority=%d "
"total models count=%s ticket num=%s",
tag,
Expand All @@ -484,7 +501,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
f"after {wait_time_elapsed:.1f} seconds.")
if not self._wait_queue or self._wait_queue[
0].ticket_num != ticket_num:
logger.info(
self.logging_info(
"Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
self._wait_in_queue(my_ticket)
continue
Expand Down Expand Up @@ -520,7 +537,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
# Path B: Concurrent
else:
if self._isolation_mode:
logger.info(
self.logging_info(
"Waiting due to isolation in progress: tag=%s ticket num%s",
tag,
ticket_num)
Expand Down Expand Up @@ -596,7 +613,7 @@ def _try_grab_from_lru(self, tag: str) -> Any:
self._total_active_jobs += 1
return target_instance

logger.info("No idle model found for tag: %s", tag)
self.logging_info("No idle model found for tag: %s", tag)
return None

def _evict_to_make_space(
Expand Down Expand Up @@ -671,17 +688,21 @@ def _delete_instance(self, instance: Any):
if isinstance(instance, str):
# If the instance is a string, it's a uuid used
# to retrieve the model from MultiProcessShared
multi_process_shared.MultiProcessShared(
lambda: "N/A", tag=instance).unsafe_hard_delete()
try:
multi_process_shared.MultiProcessShared(
lambda: "N/A", tag=instance).unsafe_hard_delete()
except (EOFError, OSError, BrokenPipeError):
# This can happen even in normal operation.
pass
if hasattr(instance, 'mock_model_unsafe_hard_delete'):
# Call the mock unsafe hard delete method for testing
instance.mock_model_unsafe_hard_delete()
del instance

def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
logger.info("Evicting Model: %s (Score %d)", tag, score)
self.logging_info("Evicting Model: %s (Score %d)", tag, score)
curr, _, _ = self._monitor.get_stats()
logger.info("Resource Usage Before Eviction: %.1f MB", curr)
self.logging_info("Resource Usage Before Eviction: %.1f MB", curr)

if key in self._idle_lru:
del self._idle_lru[key]
Expand All @@ -697,7 +718,7 @@ def _perform_eviction(self, key: str, tag: str, instance: Any, score: int):
self._monitor.refresh()
self._monitor.reset_peak()
curr, _, _ = self._monitor.get_stats()
logger.info("Resource Usage After Eviction: %.1f MB", curr)
self.logging_info("Resource Usage After Eviction: %.1f MB", curr)

def _spawn_new_model(
self,
Expand All @@ -707,7 +728,7 @@ def _spawn_new_model(
est_cost: float) -> Any:
try:
with self._cv:
logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
self.logging_info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
baseline_snap, _, _ = self._monitor.get_stats()
instance = loader_func()
_, peak_during_load, _ = self._monitor.get_stats()
Expand Down
Loading