 try:
   # pylint: disable=wrong-import-order, wrong-import-position
   import resource
+
+  from apache_beam.ml.inference.model_manager import ModelManager
 except ImportError:
   resource = None  # type: ignore[assignment]
+  ModelManager = None  # type: ignore[assignment]
 
 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
@@ -533,11 +536,12 @@ def request(
     raise NotImplementedError(type(self))
 
 
-class _ModelManager:
+class _ModelHandlerManager:
   """
-  A class for efficiently managing copies of multiple models. Will load a
-  single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object.
+  A class for efficiently managing copies of multiple model handlers.
+  Will load a single copy of each model from the model handler into a
+  multi_process_shared object and then return a lookup key for that
+  object. Used for KeyedModelHandler only.
   """
   def __init__(self, mh_map: dict[str, ModelHandler]):
     """
@@ -602,8 +606,9 @@ def load(self, key: str) -> _ModelLoadStats:
 
   def increment_max_models(self, increment: int):
     """
-    Increments the number of models that this instance of a _ModelManager is
-    able to hold. If it is never called, no limit is imposed.
+    Increments the number of models that this instance of a
+    _ModelHandlerManager is able to hold. If it is never called,
+    no limit is imposed.
     Args:
       increment: the amount by which we are incrementing the number of models.
     """
@@ -656,7 +661,7 @@ def __init__(
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[tuple[KeyT, ExampleT],
                                      tuple[KeyT, PredictionT],
-                                     Union[ModelT, _ModelManager]]):
+                                     Union[ModelT, _ModelHandlerManager]]):
   def __init__(
       self,
       unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
@@ -809,15 +814,15 @@ def __init__(
             'to exactly one model handler.')
       self._key_to_id_map[key] = keys[0]
 
-  def load_model(self) -> Union[ModelT, _ModelManager]:
+  def load_model(self) -> Union[ModelT, _ModelHandlerManager]:
     if self._single_model:
       return self._unkeyed.load_model()
-    return _ModelManager(self._id_to_mh_map)
+    return _ModelHandlerManager(self._id_to_mh_map)
 
   def run_inference(
       self,
       batch: Sequence[tuple[KeyT, ExampleT]],
-      model: Union[ModelT, _ModelManager],
+      model: Union[ModelT, _ModelHandlerManager],
       inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[tuple[KeyT, PredictionT]]:
     if self._single_model:
@@ -919,7 +924,7 @@ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
 
   def update_model_paths(
       self,
-      model: Union[ModelT, _ModelManager],
+      model: Union[ModelT, _ModelHandlerManager],
       model_paths: list[KeyModelPathMapping[KeyT]] = None):
     # When there are many models, the keyed model handler is responsible for
     # reorganizing the model handlers into cohorts and telling the model
@@ -1338,6 +1343,8 @@ def __init__(
       model_metadata_pcoll: beam.PCollection[ModelMetadata] = None,
       watch_model_pattern: Optional[str] = None,
       model_identifier: Optional[str] = None,
+      use_model_manager: bool = False,
+      model_manager_args: Optional[dict[str, Any]] = None,
       **kwargs):
     """
     A transform that takes a PCollection of examples (or features) for use
@@ -1378,6 +1385,8 @@ def __init__(
     self._exception_handling_timeout = None
     self._timeout = None
     self._watch_model_pattern = watch_model_pattern
+    self._use_model_manager = use_model_manager
+    self._model_manager_args = model_manager_args
     self._kwargs = kwargs
     # Generate a random tag to use for shared.py and multi_process_shared.py to
     # allow us to effectively disambiguate in multi-model settings. Only use
@@ -1490,7 +1499,9 @@ def expand(
             self._clock,
             self._metrics_namespace,
             load_model_at_runtime,
-            self._model_tag),
+            self._model_tag,
+            self._use_model_manager,
+            self._model_manager_args),
         self._inference_args,
         beam.pvalue.AsSingleton(
            self._model_metadata_pcoll,
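Opting in from a pipeline would look roughly like the following. This is a minimal usage sketch assuming the new `use_model_manager` / `model_manager_args` parameters land as shown above; `my_model_handler` and `examples` are placeholders, and the accepted `model_manager_args` keys are whatever `ModelManager.__init__` takes, which this diff does not spell out:

```python
from apache_beam.ml.inference.base import RunInference

predictions = (
    examples
    | RunInference(
        model_handler=my_model_handler,  # placeholder model handler
        use_model_manager=True,          # new flag from this change
        model_manager_args={}))          # forwarded to ModelManager.__init__
```

With the flag left at its default of `False`, the transform takes the pre-existing code paths unchanged.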
@@ -1803,31 +1814,75 @@ def load_model_status(
   return shared.Shared().acquire(lambda: _ModelStatus(False), tag=tag)
 
 
+class _ProxyLoader:
+  """
+  A helper callable to wrap the loader for MultiProcessShared.
+  """
+  def __init__(self, loader_func, model_tag):
+    self.loader_func = loader_func
+    self.model_tag = model_tag
+
+  def __call__(self):
+    # Generate a unique tag for the model being loaded so that
+    # we will have unique instances of the model in multi_process_shared
+    # space instead of reusing the same instance over and over. The instance
+    # will be initialized and left running as a separate process, which can
+    # then be grabbed again using the unique tag if needed during inference.
+    unique_tag = self.model_tag + '_' + uuid.uuid4().hex
+    # Ensure that each model is loaded in a different process for parallelism.
+    multi_process_shared.MultiProcessShared(
+        self.loader_func, tag=unique_tag, always_proxy=True,
+        spawn_process=True).acquire()
+    # Only return the tag to avoid pickling issues with the model itself.
+    return unique_tag
+
+
 class _SharedModelWrapper():
   """A router class to map incoming calls to the correct model.
 
   This allows us to round robin calls to models sitting in different
   processes so that we can more efficiently use resources (e.g. GPUs).
   """
-  def __init__(self, models: list[Any], model_tag: str):
+  def __init__(
+      self,
+      models: Union[list[Any], ModelManager],
+      model_tag: str,
+      loader_func: Optional[Callable[[], Any]] = None):
     self.models = models
-    if len(models) > 1:
+    self.use_model_manager = not isinstance(models, list)
+    self.model_tag = model_tag
+    self.loader_func = loader_func
+    if not self.use_model_manager and len(models) > 1:
       self.model_router = multi_process_shared.MultiProcessShared(
           lambda: _ModelRoutingStrategy(),
           tag=f'{model_tag}_counter',
           always_proxy=True).acquire()
 
   def next_model(self):
+    if self.use_model_manager:
+      loader_wrapper = _ProxyLoader(self.loader_func, self.model_tag)
+      return self.models.acquire_model(self.model_tag, loader_wrapper)
+
     if len(self.models) == 1:
       # Short circuit if there's no routing strategy needed in order to
       # avoid the cross-process call
       return self.models[0]
 
     return self.models[self.model_router.next_model_index(len(self.models))]
 
+  def release_model(self, model_tag: str, model: Any):
+    if self.use_model_manager:
+      self.models.release_model(model_tag, model)
+
   def all_models(self):
+    if self.use_model_manager:
+      return self.models.all_models()[self.model_tag]
     return self.models
 
+  def force_reset(self):
+    if self.use_model_manager:
+      self.models.force_reset()
+
 
 class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def __init__(
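The comments in `_ProxyLoader.__call__` describe a tag round-trip: the model is materialized once, under a unique tag, in its own spawned process; only the tag string crosses process boundaries; and any worker can later re-attach to the same instance by acquiring that tag. A small sketch of that pattern, using only the `MultiProcessShared` calls already present in this diff; `load_my_model` and the tag value are placeholders:

```python
from apache_beam.utils import multi_process_shared

unique_tag = 'my_model_abc123'  # placeholder unique tag

# Load once, in a separate process, under the unique tag.
multi_process_shared.MultiProcessShared(
    load_my_model, tag=unique_tag, always_proxy=True,
    spawn_process=True).acquire()

# Later, possibly from another worker process, re-attach by tag alone.
# The lambda loader is only a fallback; the instance already exists.
model_proxy = multi_process_shared.MultiProcessShared(
    lambda: None, tag=unique_tag, always_proxy=True).acquire()
```

`_run_inference` below uses exactly this second form when `next_model` returns a tag string instead of a model object.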
@@ -1836,7 +1891,9 @@ def __init__(
       clock,
       metrics_namespace,
       load_model_at_runtime: bool = False,
-      model_tag: str = "RunInference"):
+      model_tag: str = "RunInference",
+      use_model_manager: bool = False,
+      model_manager_args: Optional[dict[str, Any]] = None):
     """A DoFn implementation generic to frameworks.
 
     Args:
@@ -1860,6 +1917,8 @@ def __init__(
     # _cur_tag is the tag of the actually loaded model
     self._model_tag = model_tag
     self._cur_tag = model_tag
+    self.use_model_manager = use_model_manager
+    self._model_manager_args = model_manager_args or {}
 
   def _load_model(
       self,
@@ -1894,7 +1953,15 @@ def load():
       model_tag = side_input_model_path
     # Ensure the tag we're loading is valid, if not replace it with a valid tag
     self._cur_tag = self._model_metadata.get_valid_tag(model_tag)
-    if self._model_handler.share_model_across_processes():
+    if self.use_model_manager:
+      logging.info("Using Model Manager to manage models automatically.")
+      model_manager = multi_process_shared.MultiProcessShared(
+          lambda: ModelManager(**self._model_manager_args),
+          tag='model_manager',
+          always_proxy=True).acquire()
+      model_wrapper = _SharedModelWrapper(
+          model_manager, self._cur_tag, self._model_handler.load_model)
+    elif self._model_handler.share_model_across_processes():
       models = []
       for copy_tag in _get_tags_for_copies(self._cur_tag,
                                            self._model_handler.model_copies()):
@@ -1949,8 +2016,15 @@ def _run_inference(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     try:
       model = self._model.next_model()
+      if isinstance(model, str):
+        # ModelManager with MultiProcessShared returns the model tag
+        unique_tag = model
+        model = multi_process_shared.MultiProcessShared(
+            lambda: None, tag=model, always_proxy=True).acquire()
       result_generator = self._model_handler.run_inference(
           batch, model, inference_args)
+      if self.use_model_manager:
+        self._model.release_model(self._model_tag, unique_tag)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()
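Taken together, the `_SharedModelWrapper` call sites above assume roughly the following surface on `ModelManager`. This is an interface sketch inferred from usage only; the real class lives in `apache_beam/ml/inference/model_manager.py` and its actual signatures may differ:

```python
from typing import Any, Callable, Protocol


class _ModelManagerLike(Protocol):
  """Shape of ModelManager as used by _SharedModelWrapper (inferred)."""

  def acquire_model(self, tag: str, loader: Callable[[], str]) -> str:
    """Returns a multi_process_shared tag for a loaded model instance."""

  def release_model(self, tag: str, model: Any) -> None:
    """Signals that the instance identified by `model` is free again."""

  def all_models(self) -> dict[str, Any]:
    """Returns the currently managed models, keyed by model tag."""

  def force_reset(self) -> None:
    """Drops all managed models."""
```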