Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 224 additions & 25 deletions sdks/python/apache_beam/utils/multi_process_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import logging
import multiprocessing.managers
import os
import time
import traceback
import atexit
import sys
import tempfile
import threading
from typing import Any
Expand Down Expand Up @@ -79,6 +83,10 @@ def singletonProxy_release(self):
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False

def unsafe_hard_delete(self):
assert self._SingletonProxy_valid
self._SingletonProxy_entry.unsafe_hard_delete()

def __getattr__(self, name):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
Expand All @@ -105,13 +113,16 @@ def __dir__(self):
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
dir.append('singletonProxy_release')
dir.append('unsafe_hard_delete')
return dir


class _SingletonEntry:
"""Represents a single, refcounted entry in this process."""
def __init__(self, constructor, initialize_eagerly=True):
def __init__(
self, constructor, initialize_eagerly=True, hard_delete_callback=None):
self.constructor = constructor
self._hard_delete_callback = hard_delete_callback
self.refcount = 0
self.lock = threading.Lock()
if initialize_eagerly:
Expand Down Expand Up @@ -141,14 +152,28 @@ def unsafe_hard_delete(self):
if self.initialied:
del self.obj
self.initialied = False
if self._hard_delete_callback:
self._hard_delete_callback()


class _SingletonManager:
entries: Dict[Any, Any] = {}

def register_singleton(self, constructor, tag, initialize_eagerly=True):
def __init__(self):
self._hard_delete_callback = None

def set_hard_delete_callback(self, callback):
self._hard_delete_callback = callback

def register_singleton(
self,
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=None):
assert tag not in self.entries, tag
self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
self.entries[tag] = _SingletonEntry(
constructor, initialize_eagerly, hard_delete_callback)

def has_singleton(self, tag):
return tag in self.entries
Expand All @@ -160,7 +185,8 @@ def release_singleton(self, tag, obj):
return self.entries[tag].release(obj)

def unsafe_hard_delete_singleton(self, tag):
return self.entries[tag].unsafe_hard_delete()
self.entries[tag].unsafe_hard_delete()
self._hard_delete_callback()


_process_level_singleton_manager = _SingletonManager()
Expand Down Expand Up @@ -200,9 +226,99 @@ def __call__(self, *args, **kwargs):
def __getattr__(self, name):
return getattr(self._proxyObject, name)

def __setstate__(self, state):
self.__dict__.update(state)

def __getstate__(self):
return self.__dict__

def get_auto_proxy_object(self):
return self._proxyObject

def unsafe_hard_delete(self):
try:
self._proxyObject.unsafe_hard_delete()
except (EOFError, ConnectionResetError, BrokenPipeError):
pass
except Exception as e:
logging.warning(
"Exception %s when trying to hard delete shared object proxy", e)


def _run_server_process(address_file, tag, constructor, authkey):
"""
Runs in a separate process.
Includes a 'Suicide Pact' monitor: If parent dies, I die.
"""
parent_pid = os.getppid()

def cleanup_files():
logging.info("Server process exiting. Deleting files for %s", tag)
try:
if os.path.exists(address_file):
os.remove(address_file)
if os.path.exists(address_file + ".error"):
os.remove(address_file + ".error")
except Exception:
pass

def handle_unsafe_hard_delete():
cleanup_files()
os._exit(0)

def _monitor_parent():
"""Checks if parent is alive every second."""
while True:
try:
os.kill(parent_pid, 0)
except OSError:
logging.warning(
"Process %s detected Parent %s died. Self-destructing.",
os.getpid(),
parent_pid)
cleanup_files()
os._exit(0)
time.sleep(0.5)

atexit.register(cleanup_files)

try:
t = threading.Thread(target=_monitor_parent, daemon=True)
t.start()

logging.getLogger().setLevel(logging.INFO)
multiprocessing.current_process().authkey = authkey

serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=authkey)
_process_level_singleton_manager.set_hard_delete_callback(
handle_unsafe_hard_delete)
_process_level_singleton_manager.register_singleton(
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=handle_unsafe_hard_delete)

server = serving_manager.get_server()
logging.info(
'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)

with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % server.address)
os.rename(address_file + '.tmp', address_file)

server.serve_forever()

except Exception:
tb = traceback.format_exc()
try:
with open(address_file + ".error.tmp", 'w') as fout:
fout.write(tb)
os.rename(address_file + ".error.tmp", address_file + ".error")
except Exception:
print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
os._exit(1)


class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
Expand Down Expand Up @@ -252,7 +368,8 @@ def __init__(
tag: Any,
*,
path: str = tempfile.gettempdir(),
always_proxy: Optional[bool] = None):
always_proxy: Optional[bool] = None,
spawn_process: bool = False):
self._constructor = constructor
self._tag = tag
self._path = path
Expand All @@ -262,6 +379,7 @@ def __init__(
self._rpc_address = None
self._cross_process_lock = fasteners.InterProcessLock(
os.path.join(self._path, self._tag) + '.lock')
self._spawn_process = spawn_process

def _get_manager(self):
if self._manager is None:
Expand Down Expand Up @@ -301,6 +419,10 @@ def acquire(self):
# Caveat: They must always agree, as they will be ignored if the object
# is already constructed.
singleton = self._get_manager().acquire_singleton(self._tag)
# Trigger a sweep of zombie processes.
# calling active_children() has the side-effect of joining any finished
# processes, effectively reaping zombies from previous unsafe_hard_deletes.
if self._spawn_process: multiprocessing.active_children()
return _AutoProxyWrapper(singleton)

def release(self, obj):
Expand All @@ -315,25 +437,102 @@ def unsafe_hard_delete(self):
to this object exist, or (b) you are ok with all existing references to
this object throwing strange errors when derefrenced.
"""
self._get_manager().unsafe_hard_delete_singleton(self._tag)
try:
self._get_manager().unsafe_hard_delete_singleton(self._tag)
except (EOFError, ConnectionResetError, BrokenPipeError):
pass
except Exception as e:
logging.warning(
"Exception %s when trying to hard delete shared object %s",
e,
self._tag)

def _create_server(self, address_file):
# We need to be able to authenticate with both the manager and the process.
self._serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=AUTH_KEY)
multiprocessing.current_process().authkey = AUTH_KEY
# Initialize eagerly to avoid acting as the server if there are issues.
# Note, however, that _create_server itself is called lazily.
_process_level_singleton_manager.register_singleton(
self._constructor, self._tag, initialize_eagerly=True)
self._server = self._serving_manager.get_server()
logging.info(
'Starting proxy server at %s for shared %s',
self._server.address,
self._tag)
with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % self._server.address)
os.rename(address_file + '.tmp', address_file)
t = threading.Thread(target=self._server.serve_forever, daemon=True)
t.start()
logging.info('Done starting server')
if self._spawn_process:
error_file = address_file + ".error"

if os.path.exists(error_file):
try:
os.remove(error_file)
except OSError:
pass

ctx = multiprocessing.get_context('spawn')
p = ctx.Process(
target=_run_server_process,
args=(address_file, self._tag, self._constructor, AUTH_KEY),
daemon=False # Must be False for nested proxies
)
p.start()
logging.info("Parent: Waiting for %s to write address file...", self._tag)

def cleanup_process():
if p.is_alive():
logging.info(
"Parent: Terminating server process %s for %s", p.pid, self._tag)
p.terminate()
p.join()
try:
if os.path.exists(address_file):
os.remove(address_file)
if os.path.exists(error_file):
os.remove(error_file)
except Exception:
pass

atexit.register(cleanup_process)

start_time = time.time()
last_log = start_time
while True:
if os.path.exists(address_file):
break

if os.path.exists(error_file):
with open(error_file, 'r') as f:
error_msg = f.read()
try:
os.remove(error_file)
except OSError:
pass

if p.is_alive(): p.terminate()
raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")

if not p.is_alive():
exit_code = p.exitcode
raise RuntimeError(
"Shared Server Process died unexpectedly"
f" with exit code {exit_code}")

if time.time() - last_log > 300:
logging.warning(
"Still waiting for %s to initialize... %ss elapsed)",
self._tag,
int(time.time() - start_time))
last_log = time.time()

time.sleep(0.05)

logging.info('External process successfully started for %s', self._tag)
else:
# We need to be able to authenticate with both the manager
# and the process.
self._serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=AUTH_KEY)
multiprocessing.current_process().authkey = AUTH_KEY
# Initialize eagerly to avoid acting as the server if there are issues.
# Note, however, that _create_server itself is called lazily.
_process_level_singleton_manager.register_singleton(
self._constructor, self._tag, initialize_eagerly=True)
self._server = self._serving_manager.get_server()
logging.info(
'Starting proxy server at %s for shared %s',
self._server.address,
self._tag)
with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % self._server.address)
os.rename(address_file + '.tmp', address_file)
t = threading.Thread(target=self._server.serve_forever, daemon=True)
t.start()
logging.info('Done starting server')
Loading
Loading