From a5ee30d4d7c56f325079fbb100bc51444597c3df Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Mon, 15 Dec 2025 21:45:16 +0000
Subject: [PATCH 1/2] Allow MultiProcessShared to spawn a dedicated server
 process and hard-delete directly via the proxy object

---
 .../apache_beam/utils/multi_process_shared.py | 277 ++++++++++++++++--
 .../utils/multi_process_shared_test.py        | 218 ++++++++++++++
 2 files changed, 468 insertions(+), 27 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index aecb1284a1d4..0efa01f45570 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -25,6 +25,10 @@
 import logging
 import multiprocessing.managers
 import os
+import atexit
+import sys
+import time
+import traceback
 import tempfile
 import threading
 from typing import Any
@@ -79,6 +83,10 @@ def singletonProxy_release(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
 
+  def unsafe_hard_delete(self):
+    assert self._SingletonProxy_valid
+    self._SingletonProxy_entry.unsafe_hard_delete()
+
   def __getattr__(self, name):
     if not self._SingletonProxy_valid:
       raise RuntimeError('Entry was released.')
@@ -105,17 +113,39 @@ def __dir__(self):
     dir = self._SingletonProxy_entry.obj.__dir__()
     dir.append('singletonProxy_call__')
     dir.append('singletonProxy_release')
+    dir.append('unsafe_hard_delete')
     return dir
 
 
+def _run_with_oom_protection(func, *args, **kwargs):
+  try:
+    return func(*args, **kwargs)
+  except Exception as e:
+    # Check string to avoid hard import dependency
+    if 'CUDA out of memory' in str(e):
+      logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
+      try:
+        import gc
+        import torch
+        gc.collect()
+        torch.cuda.empty_cache()
+      except ImportError:
+        pass
+      except Exception as cleanup_error:
+        logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
+    raise e
+
+
 class _SingletonEntry:
   """Represents a single, refcounted entry in this process."""
-  def __init__(self, constructor, initialize_eagerly=True):
+  def __init__(
+      self, constructor, initialize_eagerly=True, hard_delete_callback=None):
     self.constructor = constructor
+    self._hard_delete_callback = hard_delete_callback
     self.refcount = 0
     self.lock = threading.Lock()
     if initialize_eagerly:
-      self.obj = constructor()
+      self.obj = _run_with_oom_protection(constructor)
       self.initialied = True
     else:
       self.initialied = False
@@ -123,7 +153,7 @@ def acquire(self):
     with self.lock:
       if not self.initialied:
-        self.obj = self.constructor()
+        self.obj = _run_with_oom_protection(self.constructor)
         self.initialied = True
       self.refcount += 1
       return _SingletonProxy(self)
 
@@ -141,14 +171,28 @@ def unsafe_hard_delete(self):
     if self.initialied:
       del self.obj
       self.initialied = False
+      if self._hard_delete_callback:
+        self._hard_delete_callback()
 
 
 class _SingletonManager:
   entries: Dict[Any, Any] = {}
 
-  def register_singleton(self, constructor, tag, initialize_eagerly=True):
+  def __init__(self):
+    self._hard_delete_callback = None
+
+  def set_hard_delete_callback(self, callback):
+    self._hard_delete_callback = callback
+
+  def register_singleton(
+      self,
+      constructor,
+      tag,
+      initialize_eagerly=True,
+      hard_delete_callback=None):
     assert tag not in self.entries, tag
-    self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
+    self.entries[tag] = _SingletonEntry(
+        constructor, initialize_eagerly, hard_delete_callback)
 
   def has_singleton(self, tag):
     return tag in self.entries
@@ -160,7 +204,9 @@ def release_singleton(self, tag, obj):
     return self.entries[tag].release(obj)
 
   def unsafe_hard_delete_singleton(self, tag):
-    return self.entries[tag].unsafe_hard_delete()
+    self.entries[tag].unsafe_hard_delete()
+    if self._hard_delete_callback:
+      self._hard_delete_callback()
 
 
 _process_level_singleton_manager = _SingletonManager()
@@ -200,9 +246,101 @@ def __call__(self, *args, **kwargs):
 
   def __getattr__(self, name):
     return getattr(self._proxyObject, name)
 
+  # Defined explicitly so pickling does not fall through to __getattr__,
+  # which would recurse on self._proxyObject before __dict__ is restored.
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+
+  def __getstate__(self):
+    return self.__dict__
+
   def get_auto_proxy_object(self):
     return self._proxyObject
 
+  def unsafe_hard_delete(self):
+    try:
+      self._proxyObject.unsafe_hard_delete()
+    except (EOFError, ConnectionResetError, BrokenPipeError):
+      pass
+    except Exception as e:
+      logging.warning(
+          "Exception %s when trying to hard delete shared object proxy", e)
+
+
+def _run_server_process(address_file, tag, constructor, authkey):
+  """Runs the shared-object server in a separate process.
+
+  Includes a parent-liveness monitor: if the parent dies, this process exits.
+  """
+  parent_pid = os.getppid()
+
+  def cleanup_files():
+    logging.info("Server process exiting. Deleting files for %s", tag)
+    try:
+      if os.path.exists(address_file):
+        os.remove(address_file)
+      if os.path.exists(address_file + ".error"):
+        os.remove(address_file + ".error")
+    except Exception:
+      pass
+
+  def handle_unsafe_hard_delete():
+    cleanup_files()
+    os._exit(0)
+
+  def _monitor_parent():
+    """Checks every 0.5 seconds that the parent is still alive."""
+    while True:
+      try:
+        os.kill(parent_pid, 0)
+      except OSError:
+        logging.warning(
+            "Process %s detected parent %s died. Self-destructing.",
+            os.getpid(),
+            parent_pid)
+        cleanup_files()
+        os._exit(0)
+      time.sleep(0.5)
+
+  atexit.register(cleanup_files)
+
+  try:
+    t = threading.Thread(target=_monitor_parent, daemon=True)
+    t.start()
+
+    logging.getLogger().setLevel(logging.INFO)
+    multiprocessing.current_process().authkey = authkey
+
+    serving_manager = _SingletonRegistrar(
+        address=('localhost', 0), authkey=authkey)
+    _process_level_singleton_manager.set_hard_delete_callback(
+        handle_unsafe_hard_delete)
+    _process_level_singleton_manager.register_singleton(
+        constructor,
+        tag,
+        initialize_eagerly=True,
+        hard_delete_callback=handle_unsafe_hard_delete)
+
+    server = serving_manager.get_server()
+    logging.info(
+        'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)
+
+    with open(address_file + '.tmp', 'w') as fout:
+      fout.write('%s:%d' % server.address)
+    os.rename(address_file + '.tmp', address_file)
+
+    server.serve_forever()
+
+  except Exception:
+    tb = traceback.format_exc()
+    try:
+      with open(address_file + ".error.tmp", 'w') as fout:
+        fout.write(tb)
+      os.rename(address_file + ".error.tmp", address_file + ".error")
+    except Exception:
+      print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
+    os._exit(1)
+
 
 class MultiProcessShared(Generic[T]):
   """MultiProcessShared is used to share a single object across processes.
@@ -252,7 +390,8 @@ def __init__(
       tag: Any,
       *,
       path: str = tempfile.gettempdir(),
-      always_proxy: Optional[bool] = None):
+      always_proxy: Optional[bool] = None,
+      spawn_process: bool = False):
     self._constructor = constructor
     self._tag = tag
     self._path = path
@@ -262,6 +401,7 @@ def __init__(
     self._rpc_address = None
     self._cross_process_lock = fasteners.InterProcessLock(
         os.path.join(self._path, self._tag) + '.lock')
+    self._spawn_process = spawn_process
 
   def _get_manager(self):
     if self._manager is None:
@@ -301,6 +441,11 @@ def acquire(self):
     # Caveat: They must always agree, as they will be ignored if the object
     # is already constructed.
     singleton = self._get_manager().acquire_singleton(self._tag)
+    # Trigger a sweep of zombie processes: calling active_children() has the
+    # side effect of joining any finished processes, effectively reaping
+    # zombies left behind by previous unsafe_hard_deletes.
+    if self._spawn_process:
+      multiprocessing.active_children()
     return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
@@ -315,25 +460,103 @@ def unsafe_hard_delete(self):
     to this object exist, or (b) you are ok with all existing references
     to this object throwing strange errors when derefrenced.
     """
-    self._get_manager().unsafe_hard_delete_singleton(self._tag)
+    try:
+      self._get_manager().unsafe_hard_delete_singleton(self._tag)
+    except (EOFError, ConnectionResetError, BrokenPipeError):
+      pass
+    except Exception as e:
+      logging.warning(
+          "Exception %s when trying to hard delete shared object %s",
+          e,
+          self._tag)
 
   def _create_server(self, address_file):
-    # We need to be able to authenticate with both the manager and the process.
-    self._serving_manager = _SingletonRegistrar(
-        address=('localhost', 0), authkey=AUTH_KEY)
-    multiprocessing.current_process().authkey = AUTH_KEY
-    # Initialize eagerly to avoid acting as the server if there are issues.
-    # Note, however, that _create_server itself is called lazily.
-    _process_level_singleton_manager.register_singleton(
-        self._constructor, self._tag, initialize_eagerly=True)
-    self._server = self._serving_manager.get_server()
-    logging.info(
-        'Starting proxy server at %s for shared %s',
-        self._server.address,
-        self._tag)
-    with open(address_file + '.tmp', 'w') as fout:
-      fout.write('%s:%d' % self._server.address)
-    os.rename(address_file + '.tmp', address_file)
-    t = threading.Thread(target=self._server.serve_forever, daemon=True)
-    t.start()
-    logging.info('Done starting server')
+    if self._spawn_process:
+      error_file = address_file + ".error"
+
+      if os.path.exists(error_file):
+        try:
+          os.remove(error_file)
+        except OSError:
+          pass
+
+      ctx = multiprocessing.get_context('spawn')
+      p = ctx.Process(
+          target=_run_server_process,
+          args=(address_file, self._tag, self._constructor, AUTH_KEY),
+          daemon=False  # Daemonic processes cannot spawn nested proxy servers.
+      )
+      p.start()
+      logging.info("Parent: Waiting for %s to write address file...", self._tag)
+
+      def cleanup_process():
+        if p.is_alive():
+          logging.info(
+              "Parent: Terminating server process %s for %s", p.pid, self._tag)
+          p.terminate()
+          p.join()
+        try:
+          if os.path.exists(address_file):
+            os.remove(address_file)
+          if os.path.exists(error_file):
+            os.remove(error_file)
+        except Exception:
+          pass
+
+      atexit.register(cleanup_process)
+
+      start_time = time.time()
+      last_log = start_time
+      while True:
+        if os.path.exists(address_file):
+          break
+
+        if os.path.exists(error_file):
+          with open(error_file, 'r') as f:
+            error_msg = f.read()
+          try:
+            os.remove(error_file)
+          except OSError:
+            pass
+
+          if p.is_alive():
+            p.terminate()
+          raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")
+
+        if not p.is_alive():
+          exit_code = p.exitcode
+          raise RuntimeError(
+              "Shared Server Process died unexpectedly"
+              f" with exit code {exit_code}")
+
+        if time.time() - last_log > 300:
+          logging.warning(
+              "Still waiting for %s to initialize... (%ss elapsed)",
+              self._tag,
+              int(time.time() - start_time))
+          last_log = time.time()
+
+        time.sleep(0.05)
+
+      logging.info('External process successfully started for %s', self._tag)
+    else:
+      # We need to be able to authenticate with both the manager
+      # and the process.
+      self._serving_manager = _SingletonRegistrar(
+          address=('localhost', 0), authkey=AUTH_KEY)
+      multiprocessing.current_process().authkey = AUTH_KEY
+      # Initialize eagerly to avoid acting as the server if there are issues.
+      # Note, however, that _create_server itself is called lazily.
+      _process_level_singleton_manager.register_singleton(
+          self._constructor, self._tag, initialize_eagerly=True)
+      self._server = self._serving_manager.get_server()
+      logging.info(
+          'Starting proxy server at %s for shared %s',
+          self._server.address,
+          self._tag)
+      with open(address_file + '.tmp', 'w') as fout:
+        fout.write('%s:%d' % self._server.address)
+      os.rename(address_file + '.tmp', address_file)
+      t = threading.Thread(target=self._server.serve_forever, daemon=True)
+      t.start()
+      logging.info('Done starting server')
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 0b7957632368..f3258cf0a968 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -18,6 +18,9 @@
 import logging
 import threading
+import multiprocessing
+import os
+import tempfile
 import unittest
 
 from typing import Any
@@ -82,6 +85,14 @@ def __getattribute__(self, __name: str) -> Any:
     return object.__getattribute__(self, __name)
 
 
+class SimpleClass:
+  def make_proxy(
+      self, tag: str = 'proxy_on_proxy', spawn_process: bool = False):
+    return multi_process_shared.MultiProcessShared(
+        Counter, tag=tag, always_proxy=True,
+        spawn_process=spawn_process).acquire()
+
+
 class MultiProcessSharedTest(unittest.TestCase):
   @classmethod
   def setUpClass(cls):
@@ -193,6 +204,34 @@ def test_unsafe_hard_delete(self):
 
     self.assertEqual(counter3.increment(), 1)
 
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    counter2.unsafe_hard_delete()
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True).acquire()
+    self.assertEqual(counter3.increment(), 1)
+
   def test_unsafe_hard_delete_no_op(self):
     shared1 = multi_process_shared.MultiProcessShared(
         Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)
@@ -242,6 +281,185 @@ def test_release_always_proxy(self):
     with self.assertRaisesRegex(Exception, 'released'):
       counter1.get()
 
+  def test_proxy_on_proxy(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        SimpleClass, tag='proxy_on_proxy_main', always_proxy=True)
+    instance = shared1.acquire()
+    proxy_instance = instance.make_proxy()
+    self.assertEqual(proxy_instance.increment(), 1)
+
+
+class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
+  def setUp(self):
+    tempdir = tempfile.gettempdir()
+    for tag in ['basic',
+                'proxy_on_proxy',
+                'proxy_on_proxy_main',
+                'main',
+                'to_delete',
+                'mix1',
+                'mix2',
+                'test_process_exit']:
+      for ext in ['', '.address', '.address.error']:
+        try:
+          os.remove(os.path.join(tempdir, tag + ext))
+        except OSError:
+          pass
+
+  def tearDown(self):
+    for p in multiprocessing.active_children():
+      p.terminate()
+      p.join()
+
+  def test_call(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+    self.assertEqual(shared.get(), 0)
+    self.assertEqual(shared.increment(), 1)
+    self.assertEqual(shared.increment(10), 11)
+    self.assertEqual(shared.increment(value=10), 21)
+    self.assertEqual(shared.get(), 21)
+
+  def test_proxy_on_proxy(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        SimpleClass, tag='main', always_proxy=True)
+    instance = shared1.acquire()
+    proxy_instance = instance.make_proxy(spawn_process=True)
+    self.assertEqual(proxy_instance.increment(), 1)
+    proxy_instance.unsafe_hard_delete()
+
+    proxy_instance2 = instance.make_proxy(tag='proxy_2', spawn_process=True)
+    self.assertEqual(proxy_instance2.increment(), 1)
+
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    counter2.unsafe_hard_delete()
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter4 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True,
+        spawn_process=True).acquire()
+
+    self.assertEqual(counter3.increment(), 1)
+    self.assertEqual(counter4.increment(), 1)
+
+  def test_mix_usage(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire()
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire()
+
+    self.assertEqual(shared1.get(), 0)
+    self.assertEqual(shared1.increment(), 1)
+    self.assertEqual(shared2.get(), 0)
+    self.assertEqual(shared2.increment(), 1)
+
+  def test_process_exits_on_unsafe_hard_delete(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    obj.unsafe_hard_delete()
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_process_exits_on_unsafe_hard_delete_with_manager(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True, spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    shared.unsafe_hard_delete()
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_zombie_reaping_on_acquire(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True)
+    obj = shared1.acquire()
+
+    children = multiprocessing.active_children()
+    server_pid = next(
+        p.pid for p in children if p.is_alive() and p.pid != os.getpid())
+
+    obj.unsafe_hard_delete()
+
+    try:
+      os.kill(server_pid, 0)
+      is_zombie = True
+    except OSError:
+      is_zombie = False
+    self.assertTrue(
+        is_zombie,
+        f"Server process {server_pid} was reaped too early before acquire()")
+
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
+    _ = shared2.acquire()
+
+    pid_exists = True
+    try:
+      os.kill(server_pid, 0)
+    except OSError:
+      pid_exists = False
+
+    self.assertFalse(
+        pid_exists,
+        f"Old server process {server_pid} was not reaped by acquire() sweep")
+    shared2.unsafe_hard_delete()
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

From 0e8456ebf5289615f9951e603e8d056b92386b84 Mon Sep 17 00:00:00 2001
From: AMOOOMA
Date: Mon, 15 Dec 2025 21:47:09 +0000
Subject: [PATCH 2/2] Remove CUDA OOM protection

---
 .../apache_beam/utils/multi_process_shared.py | 23 ++-----------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index 0efa01f45570..1a7a751dba89 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -117,25 +117,6 @@ def __dir__(self):
     return dir
 
 
-def _run_with_oom_protection(func, *args, **kwargs):
-  try:
-    return func(*args, **kwargs)
-  except Exception as e:
-    # Check string to avoid hard import dependency
-    if 'CUDA out of memory' in str(e):
-      logging.warning("Caught CUDA OOM during operation. Cleaning memory.")
-      try:
-        import gc
-        import torch
-        gc.collect()
-        torch.cuda.empty_cache()
-      except ImportError:
-        pass
-      except Exception as cleanup_error:
-        logging.error("Failed to clean up CUDA memory: %s", cleanup_error)
-    raise e
-
-
 class _SingletonEntry:
   """Represents a single, refcounted entry in this process."""
   def __init__(
@@ -145,7 +126,7 @@ def __init__(
     self.refcount = 0
     self.lock = threading.Lock()
     if initialize_eagerly:
-      self.obj = _run_with_oom_protection(constructor)
+      self.obj = constructor()
       self.initialied = True
     else:
       self.initialied = False
@@ -153,7 +134,7 @@ def acquire(self):
     with self.lock:
       if not self.initialied:
-        self.obj = _run_with_oom_protection(self.constructor)
+        self.obj = self.constructor()
         self.initialied = True
       self.refcount += 1
       return _SingletonProxy(self)
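
Usage sketch (illustrative, not part of the patch series): the flow below
assumes the spawn_process flag and the proxy-level unsafe_hard_delete land as
implemented above; this Counter class is a stand-in for the caller's own type.

  import apache_beam.utils.multi_process_shared as multi_process_shared

  class Counter(object):
    def __init__(self):
      self._value = 0

    def increment(self, value=1):
      self._value += value
      return self._value

  # With spawn_process=True the singleton lives in a dedicated spawned
  # process rather than in a serving thread of the first acquiring process.
  shared = multi_process_shared.MultiProcessShared(
      Counter, tag='example_counter', always_proxy=True, spawn_process=True)
  counter = shared.acquire()
  assert counter.increment() == 1

  # Hard delete tears down the spawned server process; any existing proxies
  # to the same tag raise errors when dereferenced afterwards.
  counter.unsafe_hard_delete()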
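
The parent/child rendezvous in _run_server_process and _create_server relies
on a write-then-rename handshake so a polling reader never observes a partial
address. A stripped-down sketch of the same pattern (function names here are
illustrative, not the patch's API):

  import os
  import time

  def publish_address(address_file, host, port):
    # Write to a sibling temp file, then rename. os.rename is atomic within
    # one filesystem, so readers see either no file or a complete address.
    with open(address_file + '.tmp', 'w') as fout:
      fout.write('%s:%d' % (host, port))
    os.rename(address_file + '.tmp', address_file)

  def wait_for_address(address_file, error_file, timeout=60.0, poll=0.05):
    deadline = time.time() + timeout
    while time.time() < deadline:
      # A crashed server reports its traceback through a separate error file.
      if os.path.exists(error_file):
        with open(error_file) as fin:
          raise RuntimeError('server crashed:\n' + fin.read())
      if os.path.exists(address_file):
        with open(address_file) as fin:
          host, port = fin.read().rsplit(':', 1)
        return host, int(port)
      time.sleep(poll)
    raise TimeoutError('server did not publish an address in time')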
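
The self-destruct monitor probes the parent with signal 0, which performs an
existence check without delivering a signal (POSIX behavior). A standalone
sketch of the idiom, with illustrative names; note that PID reuse can in
principle defeat this check:

  import os
  import threading
  import time

  def watch_parent(parent_pid, on_parent_death, interval=0.5):
    def _loop():
      while True:
        try:
          os.kill(parent_pid, 0)  # Existence check only; sends no signal.
        except OSError:
          # Parent is gone (or inaccessible); run the cleanup callback once.
          on_parent_death()
          return
        time.sleep(interval)

    t = threading.Thread(target=_loop, daemon=True)
    t.start()
    return t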