Draft
Changes from all commits
50 commits
f283322
Add model manager and rename modelmanager in base
AMOOOMA Dec 5, 2025
efb3f4d
Update indent
AMOOOMA Dec 5, 2025
d896b9d
RunInference with model manager
AMOOOMA Dec 5, 2025
4c0a933
fix
AMOOOMA Dec 5, 2025
a6ba692
fix pickle
AMOOOMA Dec 5, 2025
a273f6e
Fix pickling and auto proxy
AMOOOMA Dec 6, 2025
3f6a7b9
fix
AMOOOMA Dec 6, 2025
379080a
Add more tests
AMOOOMA Dec 8, 2025
fac671c
Fix test
AMOOOMA Dec 8, 2025
a2b7901
Add more test
AMOOOMA Dec 8, 2025
019ec73
Add more test
AMOOOMA Dec 8, 2025
e979c67
Add more test
AMOOOMA Dec 8, 2025
7df7d53
Add more test and error handling
AMOOOMA Dec 9, 2025
eb21943
Add logging
AMOOOMA Dec 9, 2025
4b65323
Add logging
AMOOOMA Dec 9, 2025
e2e96bd
Fix memory check
AMOOOMA Dec 9, 2025
34ea3b3
Merge branch 'apache:master' into model
AMOOOMA Dec 9, 2025
d3537ee
Fix solver check
AMOOOMA Dec 10, 2025
a2e1178
Update logging
AMOOOMA Dec 10, 2025
6c208de
Fix force reset
AMOOOMA Dec 10, 2025
a2694a7
Allow passing in model manager args
AMOOOMA Dec 11, 2025
d1e9a8f
Fix multiprocessingshared
AMOOOMA Dec 11, 2025
5860f27
Fix unsafe delete
AMOOOMA Dec 11, 2025
ed4a578
Make more fixes
AMOOOMA Dec 11, 2025
f26923c
Fix indent
AMOOOMA Dec 11, 2025
0c8d26d
Fix assert
AMOOOMA Dec 11, 2025
c8f064e
Supports model eviction
AMOOOMA Dec 12, 2025
3b0f4df
Fix override check
AMOOOMA Dec 12, 2025
bafe6e4
Fix parallel process issue and add capability to spawn process with MPS
AMOOOMA Dec 12, 2025
716d273
Remove override check
AMOOOMA Dec 12, 2025
6231af0
Update the default slack percentage
AMOOOMA Dec 12, 2025
6b428db
Skip tests if dependencies are missing
AMOOOMA Dec 13, 2025
b245c13
Reorder import
AMOOOMA Dec 13, 2025
ffd53d2
Add license
AMOOOMA Dec 13, 2025
3d8a2b5
Skip model manager import if error
AMOOOMA Dec 13, 2025
6a0bd89
Fix skipif
AMOOOMA Dec 13, 2025
e375b45
Enable model manager on the it test
AMOOOMA Dec 13, 2025
540cc8c
Add timeout and force poll gpu usage to prevent race condition
AMOOOMA Dec 13, 2025
d392438
Cleanup queue to avoid leaving zombie ticket
AMOOOMA Dec 13, 2025
ed0c2d8
Add logging for debugging
AMOOOMA Dec 13, 2025
c380699
Add more logging
AMOOOMA Dec 14, 2025
2e1b7d9
Fix delete model
AMOOOMA Dec 14, 2025
89047b7
Make sure process exit
AMOOOMA Dec 15, 2025
79eafe8
Make sure process exit without manager
AMOOOMA Dec 15, 2025
d819504
Add tests for reaping
AMOOOMA Dec 15, 2025
075ab41
Remove redundant line
AMOOOMA Dec 15, 2025
c627783
Add uuid to make sure eviction clears model
AMOOOMA Dec 15, 2025
ebb6ff5
Avoid pickling issue
AMOOOMA Dec 15, 2025
1d3f8c9
Wait on server start and log if taking too long
AMOOOMA Dec 15, 2025
970ecf5
Merge branch 'apache:master' into model
AMOOOMA Dec 15, 2025
92 changes: 76 additions & 16 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -63,8 +63,10 @@
try:
# pylint: disable=wrong-import-order, wrong-import-position
import resource
from apache_beam.ml.inference.model_manager import ModelManager
except ImportError:
resource = None # type: ignore[assignment]
ModelManager = None # type: ignore[assignment]

_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
@@ -467,11 +469,12 @@ def request(
raise NotImplementedError(type(self))


class _ModelManager:
class _ModelHandlerManager:
"""
A class for efficiently managing copies of multiple models. Will load a
single copy of each model into a multi_process_shared object and then
return a lookup key for that object.
A class for efficiently managing copies of multiple model handlers.
Will load a single copy of each model from the model handler into a
multi_process_shared object and then return a lookup key for that
object. Used for KeyedModelHandler only.
"""
def __init__(self, mh_map: dict[str, ModelHandler]):
"""
@@ -536,8 +539,9 @@ def load(self, key: str) -> _ModelLoadStats:

def increment_max_models(self, increment: int):
"""
Increments the number of models that this instance of a _ModelManager is
able to hold. If it is never called, no limit is imposed.
Increments the number of models that this instance of a
_ModelHandlerManager is able to hold. If it is never called,
no limit is imposed.
Args:
increment: the amount by which we are incrementing the number of models.
"""
@@ -590,7 +594,7 @@ def __init__(
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[tuple[KeyT, ExampleT],
tuple[KeyT, PredictionT],
Union[ModelT, _ModelManager]]):
Union[ModelT, _ModelHandlerManager]]):
def __init__(
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
@@ -743,15 +747,15 @@ def __init__(
'to exactly one model handler.')
self._key_to_id_map[key] = keys[0]

def load_model(self) -> Union[ModelT, _ModelManager]:
def load_model(self) -> Union[ModelT, _ModelHandlerManager]:
if self._single_model:
return self._unkeyed.load_model()
return _ModelManager(self._id_to_mh_map)
return _ModelHandlerManager(self._id_to_mh_map)

def run_inference(
self,
batch: Sequence[tuple[KeyT, ExampleT]],
model: Union[ModelT, _ModelManager],
model: Union[ModelT, _ModelHandlerManager],
inference_args: Optional[dict[str, Any]] = None
) -> Iterable[tuple[KeyT, PredictionT]]:
if self._single_model:
@@ -853,7 +857,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):

def update_model_paths(
self,
model: Union[ModelT, _ModelManager],
model: Union[ModelT, _ModelHandlerManager],
model_paths: list[KeyModelPathMapping[KeyT]] = None):
# When there are many models, the keyed model handler is responsible for
# reorganizing the model handlers into cohorts and telling the model
@@ -1272,6 +1276,8 @@ def __init__(
model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
watch_model_pattern: Optional[str] = None,
model_identifier: Optional[str] = None,
use_model_manager: bool = False,
model_manager_args: Optional[dict[str, Any]] = None,
**kwargs):
"""
A transform that takes a PCollection of examples (or features) for use
@@ -1312,6 +1318,8 @@ def __init__(
self._exception_handling_timeout = None
self._timeout = None
self._watch_model_pattern = watch_model_pattern
self._use_model_manager = use_model_manager
self._model_manager_args = model_manager_args
self._kwargs = kwargs
# Generate a random tag to use for shared.py and multi_process_shared.py to
# allow us to effectively disambiguate in multi-model settings. Only use
@@ -1424,7 +1432,9 @@ def expand(
self._clock,
self._metrics_namespace,
load_model_at_runtime,
self._model_tag),
self._model_tag,
self._use_model_manager,
self._model_manager_args),
self._inference_args,
beam.pvalue.AsSingleton(
self._model_metadata_pcoll,
@@ -1737,31 +1747,67 @@ def load_model_status(
return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)


class _ProxyLoader:
"""
A helper callable to wrap the loader for MultiProcessShared.
"""
def __init__(self, loader_func, model_tag):
self.loader_func = loader_func
self.model_tag = model_tag

def __call__(self):
unique_tag = self.model_tag + '_' + uuid.uuid4().hex
# Ensure that each model is loaded in a different process for parallelism
return multi_process_shared.MultiProcessShared(
self.loader_func, tag=unique_tag, always_proxy=True,
spawn_process=True).acquire()


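The wrapper above exists so the model manager can be handed a plain zero-argument callable while each actual load still lands in its own spawned process. A rough illustration of that intent; handler is a placeholder for any ModelHandler instance:

```python
# Illustrative only: 'handler' stands in for any ModelHandler.
loader = _ProxyLoader(handler.load_model, model_tag='my_model')

# Every call builds a fresh uuid-suffixed tag, so MultiProcessShared spawns a
# new process and loads an independent copy of the model there; the returned
# object is a cross-process proxy to that copy.
model_proxy_a = loader()
model_proxy_b = loader()  # second call, second process, second copy
```
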
class _SharedModelWrapper():
"""A router class to map incoming calls to the correct model.

This allows us to round robin calls to models sitting in different
processes so that we can more efficiently use resources (e.g. GPUs).
"""
def __init__(self, models: list[Any], model_tag: str):
def __init__(
self,
models: Union[list[Any], ModelManager],
model_tag: str,
loader_func: Callable[[], Any] = None):
self.models = models
if len(models) > 1:
self.use_model_manager = not isinstance(models, list)
self.model_tag = model_tag
self.loader_func = loader_func
if not self.use_model_manager and len(models) > 1:
self.model_router = multi_process_shared.MultiProcessShared(
lambda: _ModelRoutingStrategy(),
tag=f'{model_tag}_counter',
always_proxy=True).acquire()

def next_model(self):
if self.use_model_manager:
loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag)
return self.models.acquire_model(self.model_tag, loader_wrapper)
if len(self.models) == 1:
# Short circuit if there's no routing strategy needed in order to
# avoid the cross-process call
return self.models[0]

return self.models[self.model_router.next_model_index(len(self.models))]

def release_model(self, model_tag: str, model: Any):
if self.use_model_manager:
self.models.release_model(model_tag, model)

def all_models(self):
if self.use_model_manager:
return self.models.all_models()[self.model_tag]
return self.models

def force_reset(self):
if self.use_model_manager:
self.models.force_reset()


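A sketch of the two routing modes the wrapper now supports; copy_a, copy_b, model_manager, and handler are placeholders, with model_manager being the MultiProcessShared ModelManager proxy that the DoFn below acquires:

```python
# Mode 1: a list of pre-loaded copies, round-robined across processes.
wrapper = _SharedModelWrapper([copy_a, copy_b], model_tag='my_model')
model = wrapper.next_model()               # alternates between the two copies
wrapper.release_model('my_model', model)   # no-op in this mode

# Mode 2: a ModelManager, with copies acquired and released per batch.
wrapper = _SharedModelWrapper(
    model_manager, 'my_model', loader_func=handler.load_model)
model = wrapper.next_model()               # acquire_model() via a _ProxyLoader
# ... run inference with 'model' ...
wrapper.release_model('my_model', model)   # hand the copy back to the manager
```
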
class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
def __init__(
@@ -1770,7 +1816,9 @@ def __init__(
clock,
metrics_namespace,
load_model_at_runtime: bool = False,
model_tag: str = "RunInference"):
model_tag: str = "RunInference",
use_model_manager: bool = False,
model_manager_args: Optional[dict[str, Any]] = None):
"""A DoFn implementation generic to frameworks.

Args:
@@ -1794,6 +1842,8 @@ def __init__(
# _cur_tag is the tag of the actually loaded model
self._model_tag = model_tag
self._cur_tag = model_tag
self.use_model_manager = use_model_manager
self._model_manager_args = model_manager_args or {}

def _load_model(
self,
@@ -1828,7 +1878,15 @@ def load():
model_tag = side_input_model_path
# Ensure the tag we're loading is valid, if not replace it with a valid tag
self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
if self._model_handler.share_model_across_processes():
if self.use_model_manager:
logging.info("Using Model Manager to manage models automatically.")
model_manager = multi_process_shared.MultiProcessShared(
lambda: ModelManager(**self._model_manager_args),
tag='model_manager',
always_proxy=True).acquire()
model_wrapper = _SharedModelWrapper(
model_manager, self._cur_tag, self._model_handler.load_model)
elif self._model_handler.share_model_across_processes():
models = []
for copy_tag in _get_tags_for_copies(self._cur_tag,
self._model_handler.model_copies()):
@@ -1885,6 +1943,8 @@ def _run_inference(self, batch, inference_args):
model = self._model.next_model()
result_generator = self._model_handler.run_inference(
batch, model, inference_args)
if self.use_model_manager:
self._model.release_model(self._model_tag, model)
except BaseException as e:
if self._metrics_collector:
self._metrics_collector.failed_batches_counter.inc()
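Taken together, the base.py changes surface as two new RunInference arguments. A minimal pipeline sketch of the intended usage, mirroring the new tests in base_test.py below; MyModelHandler is a placeholder for any ModelHandler, and model_manager_args is forwarded to ModelManager as keyword arguments:

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

# Illustrative only: MyModelHandler stands in for any existing ModelHandler.
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'Examples' >> beam.Create([1, 5, 3, 10])
      | 'Infer' >> RunInference(
          MyModelHandler(),
          use_model_manager=True,   # opt in to ModelManager-backed loading
          model_manager_args={      # optional tuning knobs, as in the tests
              'slack_percentage': 0.2,
              'poll_interval': 1.0,
          }))
```

When use_model_manager is left at its default of False, the pre-existing loading paths (including share_model_across_processes) are used unchanged.
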
69 changes: 60 additions & 9 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -117,6 +117,19 @@ def batch_elements_kwargs(self):
return {'min_batch_size': 1, 'max_batch_size': 1}


class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def load_model(self):
return FakeModel()

def run_inference(
self,
batch: Sequence[int],
model: FakeModel,
inference_args=None) -> Iterable[int]:
for example in batch:
yield model.predict(example)


class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def __init__(
self,
@@ -310,6 +323,15 @@ def validate_inference_args(self, inference_args):
pass


def try_import_model_manager():
try:
# pylint: disable=unused-import
from apache_beam.ml.inference.model_manager import ModelManager
return True
except ImportError:
return False


class RunInferenceBaseTest(unittest.TestCase):
def test_run_inference_impl_simple_examples(self):
with TestPipeline() as pipeline:
@@ -1599,13 +1621,13 @@ def test_child_class_without_env_vars(self):
actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_model_manager_loads_shared_model(self):
def test_model_handler_manager_loads_shared_model(self):
mhs = {
'key1': FakeModelHandler(state=1),
'key2': FakeModelHandler(state=2),
'key3': FakeModelHandler(state=3)
}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
# Use bad_mh's load function to make sure we're actually loading the
# version already stored
@@ -1623,12 +1645,12 @@ def test_model_manager_loads_shared_model(self):
self.assertEqual(2, model2.predict(10))
self.assertEqual(3, model3.predict(10))

def test_model_manager_evicts_models(self):
def test_model_handler_manager_evicts_models(self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
mm.increment_max_models(2)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
@@ -1667,10 +1689,10 @@ def test_model_manager_evicts_models(self):
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))

def test_model_manager_evicts_models_after_update(self):
def test_model_handler_manager_evicts_models_after_update(self):
mh1 = FakeModelHandler(state=1)
mhs = {'key1': mh1}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
tag1 = mm.load('key1').model_tag
sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
model1 = sh1.acquire()
@@ -1697,13 +1719,12 @@ def test_model_manager_evicts_models_after_update(self):
self.assertEqual(6, model1.predict(10))
sh1.release(model1)

def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
self):
def test_model_handler_manager_evicts_correctly_after_being_incremented(self):
mh1 = FakeModelHandler(state=1)
mh2 = FakeModelHandler(state=2)
mh3 = FakeModelHandler(state=3)
mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
mm = base._ModelManager(mh_map=mhs)
mm = base._ModelHandlerManager(mh_map=mhs)
mm.increment_max_models(1)
mm.increment_max_models(1)
tag1 = mm.load('key1').model_tag
@@ -1888,6 +1909,36 @@ def test_model_status_provides_valid_garbage_collection(self):

self.assertEqual(0, len(tags))

@unittest.skipIf(
not try_import_model_manager(), 'Model Manager not available')
def test_run_inference_impl_with_model_manager(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
SimpleFakeModelHandler(), use_model_manager=True)
assert_that(actual, equal_to(expected), label='assert:inferences')

@unittest.skipIf(
not try_import_model_manager(), 'Model Manager not available')
def test_run_inference_impl_with_model_manager_args(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(
SimpleFakeModelHandler(),
use_model_manager=True,
model_manager_args={
'slack_percentage': 0.2,
'poll_interval': 1.0,
'peak_window_seconds': 10.0,
'min_data_points': 10,
'smoothing_factor': 0.5
})
assert_that(actual, equal_to(expected), label='assert:inferences')


def _always_retry(e: Exception) -> bool:
return True