diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 38b36c3a2c45..0e1012b2de65 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -486,11 +486,12 @@ def get_all_options( drop_default=False, add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None, retain_unknown_options=False, - display_warnings=False) -> Dict[str, Any]: + display_warnings=False, + current_only=False, + ) -> Dict[str, Any]: """Returns a dictionary of all defined arguments. - Returns a dictionary of all defined arguments (arguments that are defined in - any subclass of PipelineOptions) into a dictionary. + Gathers all defined arguments and their values into a dictionary. Args: drop_default: If set to true, options that are equal to their default @@ -500,6 +501,9 @@ def get_all_options( retain_unknown_options: If set to true, options not recognized by any known pipeline options class will still be included in the result. If set to false, they will be discarded. + current_only: If set to true, only returns options defined in this class. + Otherwise, arguments that are defined in any subclass of PipelineOptions + are returned (default). Returns: Dictionary of all args and values. @@ -510,8 +514,11 @@ def get_all_options( # instance of each subclass to avoid conflicts. 
subset = {} parser = _BeamArgumentParser(allow_abbrev=False) - for cls in PipelineOptions.__subclasses__(): - subset.setdefault(str(cls), cls) + if current_only: + subset.setdefault(str(type(self)), type(self)) + else: + for cls in PipelineOptions.__subclasses__(): + subset.setdefault(str(cls), cls) for cls in subset.values(): cls._add_argparse_args(parser) # pylint: disable=protected-access if add_extra_args_fn: @@ -562,7 +569,7 @@ def add_new_arg(arg, **kwargs): continue parsed_args, _ = parser.parse_known_args(self._flags) else: - if unknown_args: + if unknown_args and not current_only: _LOGGER.warning("Discarding unparseable args: %s", unknown_args) parsed_args = known_args result = vars(parsed_args) @@ -580,7 +587,7 @@ def add_new_arg(arg, **kwargs): if overrides: if retain_unknown_options: result.update(overrides) - else: + elif not current_only: _LOGGER.warning("Discarding invalid overrides: %s", overrides) return result diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index 705e8e1e2c04..c683c9625272 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -238,6 +238,19 @@ def test_get_all_options(self, flags, expected, _): options.view_as(PipelineOptionsTest.MockOptions).mock_multi_option, expected['mock_multi_option']) + def test_get_superclass_options(self): + flags = ["--mock_option", "mock", "--fake_option", "fake"] + options = PipelineOptions(flags=flags).view_as( + PipelineOptionsTest.FakeOptions) + items = options.get_all_options(current_only=True).items() + print(items) + self.assertTrue(('fake_option', 'fake') in items) + self.assertFalse(('mock_option', 'mock') in items) + items = options.view_as(PipelineOptionsTest.MockOptions).get_all_options( + current_only=True).items() + self.assertFalse(('fake_option', 'fake') in items) + self.assertTrue(('mock_option', 'mock') in items) + 
@parameterized.expand(TEST_CASES) def test_subclasses_of_pipeline_options_can_be_instantiated( self, flags, expected, _): diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py index 8975fcf1e138..bc915d300857 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner.py @@ -236,7 +236,7 @@ def run_pipeline(self, pipeline, options): 'DaskRunner is not available. Please install apache_beam[dask].') dask_options = options.view_as(DaskOptions).get_all_options( - drop_default=True) + drop_default=True, current_only=True) bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options) client = ddist.Client(**dask_options)