diff --git a/sdks/python/apache_beam/options/pipeline_options_context.py b/sdks/python/apache_beam/options/pipeline_options_context.py
new file mode 100644
index 000000000000..5f2475e334af
--- /dev/null
+++ b/sdks/python/apache_beam/options/pipeline_options_context.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Context-scoped access to pipeline options during graph construction and
+translation.
+
+This module provides thread-safe and async-safe access to globally-available
+instances of PipelineOptions using contextvars, scoped by the current pipeline.
+It allows components like transforms and coders to access the pipeline's
+configuration without requiring explicit parameter passing through every
+level of the call stack.
+
+For internal use only; no backwards-compatibility guarantees.
+"""
+
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import TYPE_CHECKING
+from typing import Optional
+
+if TYPE_CHECKING:
+  from apache_beam.options.pipeline_options import PipelineOptions
+
+# The contextvar holding the current pipeline's options.
+# Each thread and each asyncio task gets its own isolated copy.
+_pipeline_options: ContextVar[Optional['PipelineOptions']] = ContextVar(
+    'pipeline_options', default=None)
+
+
+def get_pipeline_options() -> Optional['PipelineOptions']:
+  """Get the current pipeline's options from the context.
+
+  Returns:
+    The PipelineOptions for the currently executing pipeline operation,
+    or None if called outside of a pipeline context.
+  """
+  return _pipeline_options.get()
+
+
+@contextmanager
+def scoped_pipeline_options(options: Optional['PipelineOptions']):
+  """Context manager that sets pipeline options for the duration of a block.
+
+  Args:
+    options: The PipelineOptions to make available during this scope.
+  """
+  token = _pipeline_options.set(options)
+  try:
+    yield
+  finally:
+    _pipeline_options.reset(token)
diff --git a/sdks/python/apache_beam/options/pipeline_options_context_test.py b/sdks/python/apache_beam/options/pipeline_options_context_test.py
new file mode 100644
index 000000000000..da9c8fce4aab
--- /dev/null
+++ b/sdks/python/apache_beam/options/pipeline_options_context_test.py
@@ -0,0 +1,335 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for pipeline_options_context module.
+
+These tests verify that the contextvar-based approach properly isolates
+pipeline options across threads and async tasks, preventing race conditions.
+"""
+
+import asyncio
+import threading
+import unittest
+
+import apache_beam as beam
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options_context import get_pipeline_options
+from apache_beam.options.pipeline_options_context import scoped_pipeline_options
+
+
+class PipelineConstructionOptionsTest(unittest.TestCase):
+  def test_nested_scoping(self):
+    """Test that nested scopes properly restore outer options."""
+    outer_options = PipelineOptions(['--runner=DirectRunner'])
+    inner_options = PipelineOptions(['--runner=DataflowRunner'])
+
+    with scoped_pipeline_options(outer_options):
+      self.assertIs(get_pipeline_options(), outer_options)
+
+      with scoped_pipeline_options(inner_options):
+        self.assertIs(get_pipeline_options(), inner_options)
+
+      self.assertIs(get_pipeline_options(), outer_options)
+
+    self.assertIsNone(get_pipeline_options())
+
+  def test_exception_in_scope_restores_options(self):
+    """Test that options are restored even when an exception is raised."""
+    outer_options = PipelineOptions(['--runner=DirectRunner'])
+    inner_options = PipelineOptions(['--runner=DataflowRunner'])
+
+    with scoped_pipeline_options(outer_options):
+      try:
+        with scoped_pipeline_options(inner_options):
+          self.assertIs(get_pipeline_options(), inner_options)
+          raise ValueError("Test exception")
+      except ValueError:
+        pass
+
+      self.assertIs(get_pipeline_options(), outer_options)
+
+  def test_different_threads_see_their_own_isolated_options(self):
+    """Test that different threads see their own isolated options."""
+    results = {}
+    errors = []
+    barrier = threading.Barrier(2)
+
+    def thread_worker(thread_id, runner_name):
+      try:
+        options = PipelineOptions([f'--runner={runner_name}'])
+        with scoped_pipeline_options(options):
+          barrier.wait(timeout=5)
+
+          current = get_pipeline_options()
+          results[thread_id] = current.get_all_options()['runner']
+          import time
+          time.sleep(0.01)
+
+          current_after = get_pipeline_options()
+          if current_after is not current:
+            errors.append(
+                f"Thread {thread_id}: options changed during execution")
+      except Exception as e:
+        errors.append(f"Thread {thread_id}: {e}")
+
+    thread1 = threading.Thread(target=thread_worker, args=(1, 'DirectRunner'))
+    thread2 = threading.Thread(target=thread_worker, args=(2, 'DataflowRunner'))
+
+    thread1.start()
+    thread2.start()
+
+    thread1.join(timeout=5)
+    thread2.join(timeout=5)
+
+    self.assertEqual(errors, [])
+    self.assertEqual(results[1], 'DirectRunner')
+    self.assertEqual(results[2], 'DataflowRunner')
+
+  def test_asyncio_task_isolation(self):
+    """Test that different asyncio tasks see their own isolated options."""
+    async def async_worker(
+        task_id, runner_name, results, ready_event, go_event):
+      options = PipelineOptions([f'--runner={runner_name}'])
+      with scoped_pipeline_options(options):
+        ready_event.set()
+        await go_event.wait()
+        current = get_pipeline_options()
+        results[task_id] = current.get_all_options()['runner']
+        await asyncio.sleep(0.01)
+        current_after = get_pipeline_options()
+        assert current_after is current, \
+            f"Task {task_id}: options changed during execution"
+
+    async def run_test():
+      results = {}
+      ready_events = [asyncio.Event() for _ in range(2)]
+      go_event = asyncio.Event()
+
+      task1 = asyncio.create_task(
+          async_worker(1, 'DirectRunner', results, ready_events[0], go_event))
+      task2 = asyncio.create_task(
+          async_worker(2, 'DataflowRunner', results, ready_events[1], go_event))
+
+      # Wait for both tasks to be ready
+      await asyncio.gather(*[e.wait() for e in ready_events])
+      # Signal all tasks to proceed
+      go_event.set()
+
+      await asyncio.gather(task1, task2)
+      return results
+
+    results = asyncio.run(run_test())
+    self.assertEqual(results[1], 'DirectRunner')
+    self.assertEqual(results[2], 'DataflowRunner')
+
+  def test_transform_sees_pipeline_options(self):
+    """Test that a transform can access pipeline options during expand()."""
+    class OptionsCapturingTransform(beam.PTransform):
+      """Transform that captures pipeline options during expand()."""
+      def __init__(self, expected_job_name):
+        self.expected_job_name = expected_job_name
+        self.captured_options = None
+
+      def expand(self, pcoll):
+        # This runs during pipeline construction
+        self.captured_options = get_pipeline_options()
+        return pcoll | beam.Map(lambda x: x)
+
+    options = PipelineOptions(['--job_name=test_job_123'])
+    transform = OptionsCapturingTransform('test_job_123')
+
+    with beam.Pipeline(options=options) as p:
+      _ = p | beam.Create([1, 2, 3]) | transform
+
+    # Verify the transform saw the correct options
+    self.assertIsNotNone(transform.captured_options)
+    self.assertEqual(
+        transform.captured_options.get_all_options()['job_name'],
+        'test_job_123')
+
+  def test_coder_sees_correct_options_during_run(self):
+    """Test that coders see correct pipeline options during proto conversion.
+
+    This tests the run path where as_deterministic_coder() is called during
+    to_runner_api() proto conversion.
+ """ + from apache_beam.coders import coders + from apache_beam.utils import shared + + errors = [] + + class WeakRefDict(dict): + pass + + class TestKey: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, TestKey) and self.value == other.value + + def __hash__(self): + return hash(self.value) + + class OptionsCapturingKeyCoder(coders.Coder): + """Coder that captures pipeline options in as_deterministic_coder.""" + shared_handle = shared.Shared() + + def encode(self, value): + return str(value.value).encode('utf-8') + + def decode(self, encoded): + return TestKey(encoded.decode('utf-8')) + + def is_deterministic(self): + return False + + def as_deterministic_coder(self, step_label, error_message=None): + opts = get_pipeline_options() + if opts is not None: + results = OptionsCapturingKeyCoder.shared_handle.acquire(WeakRefDict) + job_name = opts.get_all_options().get('job_name') + results['Worker1'] = job_name + return self + + beam.coders.registry.register_coder(TestKey, OptionsCapturingKeyCoder) + + results = OptionsCapturingKeyCoder.shared_handle.acquire(WeakRefDict) + + job_name = 'gbk_job' + options = PipelineOptions([f'--job_name={job_name}']) + + with beam.Pipeline(options=options) as p: + _ = ( + p + | beam.Create([(TestKey(1), 'a'), (TestKey(2), 'b')]) + | beam.GroupByKey()) + + self.assertEqual(errors, [], f"Errors occurred: {errors}") + self.assertEqual( + results.get('Worker1'), + job_name, + f"Worker1 saw wrong options: {results}") + self.assertFalse(get_pipeline_options() == options) + + def test_barrier_inside_default_type_hints(self): + """Test race condition detection with barrier inside default_type_hints. + + This test reliably detects race conditions because: + 1. Both threads start pipeline construction simultaneously + 2. Inside default_type_hints (which is called during Pipeline.apply()), + both threads hit a barrier and wait for each other + 3. At this point, BOTH threads are inside scoped_pipeline_options + 4. 
When they continue, they read options - with a global var, they'd see + the wrong values because the last thread to set options would win + """ + + results = {} + errors = [] + inner_barrier = threading.Barrier(2) + + class BarrierTransform(beam.PTransform): + """Transform that synchronizes threads INSIDE default_type_hints.""" + def __init__(self, worker_id, results_dict, barrier): + self.worker_id = worker_id + self.results_dict = results_dict + self.barrier = barrier + + def expand(self, pcoll): + return pcoll | beam.Map(lambda x: x) + + def default_type_hints(self): + self.barrier.wait(timeout=5) + opts = get_pipeline_options() + if opts is not None: + job_name = opts.get_all_options().get('job_name') + if self.worker_id not in self.results_dict: + self.results_dict[self.worker_id] = job_name + + return super().default_type_hints() + + def construct_pipeline(worker_id): + try: + job_name = f'barrier_job_{worker_id}' + options = PipelineOptions([f'--job_name={job_name}']) + transform = BarrierTransform(worker_id, results, inner_barrier) + + with beam.Pipeline(options=options) as p: + _ = p | beam.Create([1, 2, 3]) | transform + + except Exception as e: + import traceback + errors.append(f"Worker {worker_id}: {e}\n{traceback.format_exc()}") + + thread1 = threading.Thread( + target=construct_pipeline, args=(1, )) + thread2 = threading.Thread( + target=construct_pipeline, args=(2, )) + + thread1.start() + thread2.start() + + thread1.join(timeout=10) + thread2.join(timeout=10) + + self.assertEqual(errors, [], f"Errors occurred: {errors}") + self.assertEqual( + results.get(1), + 'barrier_job_1', + f"Worker 1 saw wrong options: {results}") + self.assertEqual( + results.get(2), + 'barrier_job_2', + f"Worker 2 saw wrong options: {results}") + + +class PipelineSubclassApplyTest(unittest.TestCase): + def test_subclass_apply_called_on_recursive_paths(self): + """Test that Pipeline subclass overrides of apply() are respected. + + _apply_internal's recursive calls must go through self.apply(), not + self._apply_internal(), so that subclass interceptions are not skipped. + """ + apply_calls = [] + + class TrackingPipeline(beam.Pipeline): + def apply(self, transform, pvalueish=None, label=None): + apply_calls.append(label or transform.label) + return super().apply(transform, pvalueish, label) + + options = PipelineOptions(['--job_name=subclass_test']) + with TrackingPipeline(options=options) as p: + # "my_label" >> transform creates a _NamedPTransform, which triggers + # two recursive apply() calls: one to unwrap _NamedPTransform, and + # one to handle the label argument. + _ = p | beam.Create([1, 2, 3]) | "my_label" >> beam.Map(lambda x: x) + + # beam.Create goes through apply() once (no recursion). + # "my_label" >> Map triggers: apply(_NamedPTransform) -> apply(Map, + # label="my_label") -> apply(Map). That's 3 calls through apply(). + # Total: 1 (Create) + 3 (Map) = 4 calls minimum. + map_calls = [c for c in apply_calls if c == 'my_label' or c == 'Map'] + self.assertGreaterEqual( + len(map_calls), + 3, + f"Expected at least 3 apply() calls for the Map transform " + f"(NamedPTransform unwrap + label handling + final), " + f"got {len(map_calls)}. 
All calls: {apply_calls}") + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 6ef06abb7436..a6080f2f3e7f 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -81,6 +81,7 @@ from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import StreamingOptions from apache_beam.options.pipeline_options import TypeOptions +from apache_beam.options.pipeline_options_context import scoped_pipeline_options from apache_beam.options.pipeline_options_validator import PipelineOptionsValidator from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 @@ -559,6 +560,12 @@ def replace_all(self, replacements: Iterable['PTransformOverride']) -> None: def run(self, test_runner_api: Union[bool, str] = 'AUTO') -> 'PipelineResult': """Runs the pipeline. Returns whatever our runner returns after running.""" + with scoped_pipeline_options(self._options): + return self._run_internal(test_runner_api) + + def _run_internal( + self, test_runner_api: Union[bool, str] = 'AUTO') -> 'PipelineResult': + """Internal implementation of run(), called within scoped options.""" # All pipeline options are finalized at this point. # Call get_all_options to print warnings on invalid options. self.options.get_all_options( @@ -698,6 +705,15 @@ def apply( RuntimeError: if the transform object was already applied to this pipeline and needs to be cloned in order to apply again. """ + with scoped_pipeline_options(self._options): + return self._apply_internal(transform, pvalueish, label) + + def _apply_internal( + self, + transform: ptransform.PTransform, + pvalueish: Optional[pvalue.PValue] = None, + label: Optional[str] = None) -> pvalue.PValue: + """Internal implementation of apply(), called within scoped options.""" if isinstance(transform, ptransform._NamedPTransform): return self.apply( transform.transform, pvalueish, label or transform.label)