diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 94a467d5a249..e081185b5507 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -218,14 +218,12 @@ def run(
     """Run the job"""
     try:
       state_stream = self.job_service.GetStateStream(
-          beam_job_api_pb2.GetJobStateRequest(job_id=preparation_id),
-          timeout=self.timeout)
+          beam_job_api_pb2.GetJobStateRequest(job_id=preparation_id))
       # If there's an error, we don't always get it until we try to read.
       # Fortunately, there's always an immediate current state published.
       state_stream = itertools.chain([next(state_stream)], state_stream)
       message_stream = self.job_service.GetMessageStream(
-          beam_job_api_pb2.JobMessagesRequest(job_id=preparation_id),
-          timeout=self.timeout)
+          beam_job_api_pb2.JobMessagesRequest(job_id=preparation_id))
     except Exception:
       # TODO(https://github.com/apache/beam/issues/19284): Unify preparation_id
       # and job_id for all runners.
diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index de4b94bc5da3..0c672910f257 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -26,6 +26,7 @@
 import logging
 import multiprocessing.managers
 import os
+import socket
 import tempfile
 import threading
 import time
@@ -85,6 +86,7 @@ def singletonProxy_release(self):
   def singletonProxy_unsafe_hard_delete(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_entry.unsafe_hard_delete()
+    self._SingletonProxy_valid = False
 
   def __getattr__(self, name):
     if not self._SingletonProxy_valid:
@@ -231,6 +233,26 @@ def unsafe_hard_delete(self):
     self._proxyObject.unsafe_hard_delete()
 
 
+def _wait_for_server_readiness(address, timeout=60):
+  start = time.time()
+  wait_secs = 0.1
+
+  while time.time() - start < timeout:
+    try:
+      s = socket.create_connection(address, timeout=wait_secs)
+      s.close()
+      return
+    except OSError:
+      wait_secs *= 1.2
+      logging.log(
+          logging.WARNING if wait_secs > 1 else logging.DEBUG,
+          'Waiting for server to be ready at %s',
+          address)
+
+  raise RuntimeError(
+      f"Server at {address} failed to accept connections within {timeout}s")
+
+
 def _run_server_process(address_file, tag, constructor, authkey, life_line):
   """
   Runs in a separate process.
@@ -292,9 +314,14 @@ def _monitor_parent():
 
   logging.info(
       'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)
-  with open(address_file + '.tmp', 'w') as fout:
-    fout.write('%s:%d' % server.address)
-  os.rename(address_file + '.tmp', address_file)
+  def publish_address():
+    _wait_for_server_readiness(server.address)
+    with open(address_file + '.tmp', 'w') as fout:
+      fout.write('%s:%d' % server.address)
+    os.rename(address_file + '.tmp', address_file)
+
+  t_pub = threading.Thread(target=publish_address)
+  t_pub.start()
 
   server.serve_forever()
 
@@ -392,12 +419,30 @@ def _get_manager(self):
       manager = _SingletonRegistrar(
           address=(host, int(port)), authkey=AUTH_KEY)
       multiprocessing.current_process().authkey = AUTH_KEY
-      try:
-        manager.connect()
-        self._manager = manager
-      except ConnectionError:
-        # The server is no longer good, assume it died.
-        os.unlink(address_file)
+      last_error = None
+      for attempt in range(
+          3):  # Retry transient connection failures (e.g. CI)
+        try:
+          manager.connect()
+          self._manager = manager
+          last_error = None
+          break
+        except (ConnectionError, OSError) as e:
+          last_error = e
+          if attempt < 2:
+            time.sleep(0.2 * (attempt + 1))
+      if self._manager is None and last_error is not None:
+        # Only unlink and retry from scratch if we use a separate server
+        # process; in-process server state would be stale and re-entry
+        # would raise.
+        if self._spawn_process:
+          logging.warning(
+              'Connection to proxy at %s failed after retries: %s',
+              address,
+              last_error)
+          os.unlink(address_file)
+        else:
+          raise last_error
 
     return self._manager
 
@@ -497,7 +542,13 @@ def cleanup_process():
             "Shared Server Process died unexpectedly"
             f" with exit code {exit_code}")
 
-      if time.time() - last_log > 300:
+      if time.time() - start_time > 60:
+        if p.is_alive():
+          p.terminate()
+        raise RuntimeError(
+            "Shared Server Process failed to initialize within 60 seconds")
+
+      if time.time() - last_log > 5:
         logging.warning(
             "Still waiting for %s to initialize... %ss elapsed)",
             self._tag,
@@ -522,9 +573,12 @@
         'Starting proxy server at %s for shared %s',
         self._server.address,
         self._tag)
+    t = threading.Thread(target=self._server.serve_forever, daemon=True)
+    t.start()
+
+    _wait_for_server_readiness(self._server.address)
+
     with open(address_file + '.tmp', 'w') as fout:
       fout.write('%s:%d' % self._server.address)
     os.rename(address_file + '.tmp', address_file)
-    t = threading.Thread(target=self._server.serve_forever, daemon=True)
-    t.start()
     logging.info('Done starting server')
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 7b2b11857bfd..2c7758564f4e 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -19,10 +19,12 @@
 import logging
 import multiprocessing
 import os
+import shutil
 import tempfile
 import threading
 import unittest
 from typing import Any
+from unittest import mock
 
 from apache_beam.utils import multi_process_shared
 
@@ -460,6 +462,118 @@ def test_zombie_reaping_on_acquire(self):
       pass
 
 
+class WaitForServerReadinessTest(unittest.TestCase):
+  def test_wait_for_server_readiness_timeout_raises(self):
+    with mock.patch.object(multi_process_shared.socket,
+                           'create_connection',
+                           side_effect=OSError('connection refused')):
+      with self.assertRaises(RuntimeError) as ctx:
+        multi_process_shared._wait_for_server_readiness(('localhost', 12345),
+                                                        timeout=0.2)
+      self.assertIn('failed to accept connections', str(ctx.exception))
+
+  def test_wait_for_server_readiness_success(self):
+    mock_socket = mock.Mock()
+    with mock.patch.object(multi_process_shared.socket,
+                           'create_connection',
+                           return_value=mock_socket):
+      multi_process_shared._wait_for_server_readiness(('localhost', 12345),
+                                                      timeout=1.0)
+    mock_socket.close.assert_called_once()
+
+  def test_wait_for_server_readiness_retries_on_oserror_then_succeeds(self):
+    mock_socket = mock.Mock()
+    with mock.patch.object(multi_process_shared.socket,
+                           'create_connection',
+                           side_effect=[OSError('refused'),
+                                        OSError('refused'),
+                                        mock_socket]):
+      multi_process_shared._wait_for_server_readiness(('localhost', 12345),
+                                                      timeout=2.0)
+    mock_socket.close.assert_called_once()
+
+
+class AutoProxyWrapperUnsafeHardDeleteTest(unittest.TestCase):
+  def test_wrapper_unsafe_hard_delete(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_wrapper_unsafe_hard_delete', always_proxy=True)
+    obj = shared.acquire()
+    self.assertEqual(obj.get(), 0)
+    obj.increment()
+    try:
+      obj.unsafe_hard_delete()
+    except Exception:
+      pass
+    with self.assertRaises(Exception):
+      obj.get()
+
+
+class GetManagerRetryTest(unittest.TestCase):
+  def setUp(self):
+    self.tempdir = tempfile.mkdtemp()
+    self.addCleanup(shutil.rmtree, self.tempdir, ignore_errors=True)
+
+  def test_get_manager_retries_on_connection_error_then_succeeds(self):
+    address_file = os.path.join(self.tempdir, 'tag_retry') + '.address'
+    with open(address_file, 'w') as f:
+      f.write('127.0.0.1:0')
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='tag_retry', path=self.tempdir, always_proxy=True)
+    with mock.patch.object(multi_process_shared._SingletonRegistrar,
+                           'connect',
+                           side_effect=[ConnectionError('refused'), None]):
+      manager = shared._get_manager()
+    self.assertIsNotNone(manager)
+
+  def test_get_manager_raises_when_connection_fails_no_spawn(self):
+    address_file = os.path.join(self.tempdir, 'tag_fail') + '.address'
+    with open(address_file, 'w') as f:
+      f.write('127.0.0.1:99999')
+    shared = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='tag_fail',
+        path=self.tempdir,
+        always_proxy=True,
+        spawn_process=False)
+
+    with mock.patch.object(multi_process_shared._SingletonRegistrar,
+                           'connect',
+                           side_effect=ConnectionError('refused')):
+      with self.assertRaises((ConnectionError, OSError)):
+        shared._get_manager()
+
+  def test_get_manager_unlinks_when_spawn_connection_fails(self):
+    address_file = os.path.join(self.tempdir, 'tag_spawn_fail') + '.address'
+    with open(address_file, 'w') as f:
+      f.write('127.0.0.1:99999')
+    shared = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='tag_spawn_fail',
+        path=self.tempdir,
+        always_proxy=False,
+        spawn_process=True)
+    unlink_calls = []
+
+    def track_unlink(path):
+      unlink_calls.append(path)
+      os.unlink(path)
+
+    def mock_create_server(_):
+      multi_process_shared._process_level_singleton_manager.register_singleton(
+          Counter, shared._tag)
+      shared._manager = multi_process_shared._process_level_singleton_manager
+
+    with mock.patch.object(multi_process_shared._SingletonRegistrar,
+                           'connect',
+                           side_effect=ConnectionError('refused')):
+      with mock.patch.object(multi_process_shared.os,
+                             'unlink',
+                             side_effect=track_unlink):
+        with mock.patch.object(shared, '_create_server', mock_create_server):
+          shared._get_manager()
+    self.assertGreater(len(unlink_calls), 0)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
index f83697732598..843db0170b23 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py
@@ -55,8 +55,15 @@ def new_pipeline():
   return beam.Pipeline(
       options=beam.options.pipeline_options.PipelineOptions(
           pickle_library='cloudpickle'))
 
 
+def new_pipeline_expand_test():
+  return beam.Pipeline(
+      runner='FnApiRunner',
+      options=beam.options.pipeline_options.PipelineOptions(
+          pickle_library='cloudpickle'))
+
+
 @unittest.skipIf(jsonschema is None, "Yaml dependencies not installed")
 class MainTest(unittest.TestCase):
   def assertYaml(self, expected, result):
@@ -1048,7 +1055,7 @@ def test_expand_pipeline_with_pipeline_key_only(self):
                 elements: [1,2,3]
             - type: LogForTesting
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
       expand_pipeline(p, spec, validate_schema=None)
 
   def test_expand_pipeline_with_pipeline_and_option_keys(self):
@@ -1063,7 +1070,7 @@ def test_expand_pipeline_with_pipeline_and_option_keys(self):
         options:
           streaming: false
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
       expand_pipeline(p, spec, validate_schema=None)
 
   def test_expand_pipeline_with_extra_top_level_keys(self):
@@ -1082,7 +1089,7 @@ def test_expand_pipeline_with_extra_top_level_keys(self):
 
         other_metadata: "This is an ignored comment."
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
       expand_pipeline(p, spec, validate_schema=None)
 
   def test_expand_pipeline_with_incorrect_pipelines_key_fails(self):
@@ -1095,7 +1102,7 @@ def test_expand_pipeline_with_incorrect_pipelines_key_fails(self):
                 elements: [1,2,3]
             - type: LogForTesting
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
       with self.assertRaises(KeyError):
         expand_pipeline(p, spec, validate_schema=None)
 
@@ -1110,7 +1117,7 @@ def test_expand_pipeline_with_valid_schema(self):
                 elements: [1,2,3]
             - type: LogForTesting
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
       expand_pipeline(p, spec, validate_schema='generic')
 
   @unittest.skipIf(jsonschema is None, "Yaml dependencies not installed")
@@ -1124,7 +1131,7 @@ def test_expand_pipeline_with_invalid_schema(self):
                 elements: [1,2,3]
             - type: LogForTesting
         '''
-    with new_pipeline() as p:
+    with new_pipeline_expand_test() as p:
      with self.assertRaises(jsonschema.ValidationError):
        expand_pipeline(p, spec, validate_schema='generic')
 
diff --git a/sdks/python/scripts/run_tox_cleanup.sh b/sdks/python/scripts/run_tox_cleanup.sh
index be4409525b53..89f5a6c61810 100755
--- a/sdks/python/scripts/run_tox_cleanup.sh
+++ b/sdks/python/scripts/run_tox_cleanup.sh
@@ -35,7 +35,7 @@ set -e
 for dir in apache_beam target/build; do
   if [ -d "${dir}" ]; then
     for ext in pyc c so; do
-      find ${dir} -type f -name "*.${ext}" -delete
+      find ${dir} -type f -name "*.${ext}" -delete || true
     done
   fi
 done
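
The multi_process_shared.py changes above all serve one publish-after-ready pattern: start the proxy server, probe its address until a TCP connection actually succeeds, and only then write the address file that other processes read. Below is a minimal, self-contained sketch of that pattern using only the Python standard library; every name in it (wait_until_accepting, EchoHandler) is illustrative and not part of Beam's API.

# Standalone sketch of the publish-after-ready pattern (hypothetical names).
import socket
import socketserver
import threading
import time


def wait_until_accepting(address, timeout=10):
  # Probe `address` until a TCP connect succeeds or the deadline passes.
  deadline = time.time() + timeout
  delay = 0.1
  while time.time() < deadline:
    try:
      socket.create_connection(address, timeout=delay).close()
      return
    except OSError:
      time.sleep(delay)
      delay = min(delay * 1.2, 1.0)
  raise RuntimeError('server at %s:%d never became ready' % address)


class EchoHandler(socketserver.BaseRequestHandler):
  def handle(self):
    self.request.sendall(self.request.recv(1024))


if __name__ == '__main__':
  server = socketserver.TCPServer(('localhost', 0), EchoHandler)
  threading.Thread(target=server.serve_forever, daemon=True).start()
  # Wait for readiness first; only then advertise the address (the real code
  # writes it to a .tmp file and renames it into place).
  wait_until_accepting(server.server_address)
  print('ready at %s:%d' % server.server_address)
  server.shutdown()
  server.server_close()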