diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index c6ec17f48412..d6818d275f1f 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,6 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", + "pr": "36271", "modification": 35 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index 99a8fc8ff6d5..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 14 + "modification": 0 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json index f1ba03a243ee..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 5 + "modification": 0 } diff --git a/CHANGES.md b/CHANGES.md index bab7182539e6..dfad320a694d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -85,6 +85,7 @@ ## Bugfixes * Fixed FirestoreV1 Beam connectors allow configuring inconsistent project/database IDs between RPC requests and routing headers #36895 (Java) ([#36895](https://github.com/apache/beam/issues/36895)). + Logical type and coder registry are saved for pipelines in the case of default pickler. This fixes a side effect of switching to cloudpickle as default pickler in Beam 2.65.0 (Python) ([#35738](https://github.com/apache/beam/issues/35738)). ## Known Issues diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 779c65dc772c..ef75a21ce9ef 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -114,6 +114,14 @@ def _register_coder_internal( typehint_coder_class: Type[coders.Coder]) -> None: self._coders[typehint_type] = typehint_coder_class + @staticmethod + def _normalize_typehint_type(typehint_type): + if typehint_type.__module__ == '__main__': + # See https://github.com/apache/beam/issues/21541 + # TODO(robertwb): Remove once all runners are portable. + return getattr(typehint_type, '__name__', str(typehint_type)) + return typehint_type + def register_coder( self, typehint_type: Any, typehint_coder_class: Type[coders.Coder]) -> None: @@ -123,11 +131,8 @@ def register_coder( 'Received %r instead.' % typehint_coder_class) if typehint_type not in self.custom_types: self.custom_types.append(typehint_type) - if typehint_type.__module__ == '__main__': - # See https://github.com/apache/beam/issues/21541 - # TODO(robertwb): Remove once all runners are portable. - typehint_type = getattr(typehint_type, '__name__', str(typehint_type)) - self._register_coder_internal(typehint_type, typehint_coder_class) + self._register_coder_internal( + self._normalize_typehint_type(typehint_type), typehint_coder_class) def get_coder(self, typehint: Any) -> coders.Coder: if typehint and typehint.__module__ == '__main__': @@ -170,9 +175,15 @@ def get_coder(self, typehint: Any) -> coders.Coder: coder = self._fallback_coder return coder.from_type_hint(typehint, self) - def get_custom_type_coder_tuples(self, types): + def get_custom_type_coder_tuples(self, types=None): """Returns type/coder tuples for all custom types passed in.""" - return [(t, self._coders[t]) for t in types if t in self.custom_types] + return [(t, self._coders[self._normalize_typehint_type(t)]) + for t in self.custom_types if (types is None or t in types)] + + def load_custom_type_coder_tuples(self, type_coder): + """Load type/coder tuples into coder registry.""" + for t, c in type_coder: + self.register_coder(t, c) def verify_deterministic(self, key_coder, op_name, silent=True): if not key_coder.is_deterministic(): diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index 199294f1731d..acdcc46cd40d 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -252,12 +252,35 @@ def _lock_reducer(obj): def dump_session(file_path): - # It is possible to dump session with cloudpickle. However, since references - # are saved it should not be necessary. See https://s.apache.org/beam-picklers - pass + # Since References are saved (https://s.apache.org/beam-picklers), we only + # dump supported Beam Registries (currently only logical type registry) + from apache_beam.coders import typecoders + from apache_beam.typehints import schemas + + with _pickle_lock, open(file_path, 'wb') as file: + coder_reg = typecoders.registry.get_custom_type_coder_tuples() + logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom() + + pickler = cloudpickle.CloudPickler(file) + # TODO(https://github.com/apache/beam/issues/18500) add file system registry + # once implemented + pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg}) def load_session(file_path): - # It is possible to load_session with cloudpickle. However, since references - # are saved it should not be necessary. See https://s.apache.org/beam-picklers - pass + from apache_beam.coders import typecoders + from apache_beam.typehints import schemas + + with _pickle_lock, open(file_path, 'rb') as file: + registries = cloudpickle.load(file) + if type(registries) != dict: + raise ValueError( + "Faled loading session: expected dict, got {}", type(registries)) + if "coder" in registries: + typecoders.registry.load_custom_type_coder_tuples(registries["coder"]) + else: + _LOGGER.warning('No coder registry found in saved session') + if "logical_type" in registries: + schemas.LogicalType._known_logical_types.load(registries["logical_type"]) + else: + _LOGGER.warning('No logical type registry found in saved session') diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py index 4a51c56c24be..99fbb03ac2e4 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import os +import tempfile import threading import types import unittest @@ -31,6 +32,7 @@ from apache_beam.internal import module_test from apache_beam.internal.cloudpickle_pickler import dumps from apache_beam.internal.cloudpickle_pickler import loads +from apache_beam.typehints.schemas import LogicalTypeRegistry from apache_beam.utils import shared GLOBAL_DICT_REF = module_test.GLOBAL_DICT @@ -244,6 +246,24 @@ def sample_func(): unpickled_filename = os.path.abspath(unpickled_code.co_filename) self.assertEqual(unpickled_filename, original_filename) + @mock.patch( + "apache_beam.coders.typecoders.registry.load_custom_type_coder_tuples") + @mock.patch( + "apache_beam.typehints.schemas.LogicalType._known_logical_types.load") + def test_dump_load_session(self, logicaltype_mock, coder_mock): + session_file = 'pickled' + + with tempfile.TemporaryDirectory() as tmp_dirname: + pickled_session_file = os.path.join(tmp_dirname, session_file) + beam_cloudpickle.dump_session(pickled_session_file) + beam_cloudpickle.load_session(pickled_session_file) + load_logical_types = logicaltype_mock.call_args.args + load_coders = coder_mock.call_args.args + self.assertEqual(len(load_logical_types), 1) + self.assertEqual(len(load_coders), 1) + self.assertTrue(isinstance(load_logical_types[0], LogicalTypeRegistry)) + self.assertTrue(isinstance(load_coders[0], list)) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py index 0af3b16ec053..3626b599a5c4 100644 --- a/sdks/python/apache_beam/internal/pickler.py +++ b/sdks/python/apache_beam/internal/pickler.py @@ -91,6 +91,14 @@ def load_session(file_path): return desired_pickle_lib.load_session(file_path) +def is_currently_dill(): + return desired_pickle_lib == dill_pickler + + +def is_currently_cloudpickle(): + return desired_pickle_lib == cloudpickle_pickler + + def set_library(selected_library=DEFAULT_PICKLE_LIB): """ Sets pickle library that will be used. """ global desired_pickle_lib @@ -108,12 +116,11 @@ def set_library(selected_library=DEFAULT_PICKLE_LIB): "Pipeline option pickle_library=dill_unsafe is set, but dill is not " "installed. Install dill in job submission and runtime environments.") - is_currently_dill = (desired_pickle_lib == dill_pickler) dill_is_requested = ( selected_library == USE_DILL or selected_library == USE_DILL_UNSAFE) # If switching to or from dill, update the pickler hook overrides. - if is_currently_dill != dill_is_requested: + if is_currently_dill() != dill_is_requested: dill_pickler.override_pickler_hooks(selected_library == USE_DILL) if dill_is_requested: diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index f2addf6f9d53..170ade224c10 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -64,7 +64,10 @@ # Map defined with option names to flag names for boolean options # that have a destination(dest) in parser.add_argument() different # from the flag name and whose default value is `None`. -_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'} +_FLAG_THAT_SETS_FALSE_VALUE = { + 'use_public_ips': 'no_use_public_ips', + 'save_main_session': 'no_save_main_session' +} # Set of options which should not be overriden when applying options from a # different language. This is relevant when using x-lang transforms where the # expansion service is started up with some pipeline options, and will @@ -1672,7 +1675,7 @@ def _add_argparse_args(cls, parser): choices=['cloudpickle', 'default', 'dill', 'dill_unsafe']) parser.add_argument( '--save_main_session', - default=False, + default=None, action='store_true', help=( 'Save the main session state so that pickled functions and classes ' @@ -1680,6 +1683,15 @@ def _add_argparse_args(cls, parser): 'Some workflows do not need the session state if for instance all ' 'their functions/classes are defined in proper modules ' '(not __main__) and the modules are importable in the worker. ')) + parser.add_argument( + '--no_save_main_session', + default=None, + action='store_false', + dest='save_main_session', + help=( + 'Disable saving the main session state. It is enabled/disabled by' + 'default for cloudpickle/dill pickler. See "save_main_session".')) + parser.add_argument( '--sdk_location', default='default', @@ -1780,10 +1792,23 @@ def _add_argparse_args(cls, parser): 'If not specified, the default Maven Central repository will be ' 'used.')) + def _handle_load_main_session(self, validator): + save_main_session = getattr(self, 'save_main_session') + if save_main_session is None: + # save_main_session default to False for dill, while default to true + # for cloudpickle + pickle_library = getattr(self, 'pickle_library') + if pickle_library in ['default', 'cloudpickle']: + setattr(self, 'save_main_session', True) + else: + setattr(self, 'save_main_session', False) + return [] + def validate(self, validator): errors = [] errors.extend(validator.validate_container_prebuilding_options(self)) errors.extend(validator.validate_pickle_library(self)) + errors.extend(self._handle_load_main_session(validator)) return errors diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py index 178a75ec41d9..d5d8ba662f06 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py @@ -42,6 +42,7 @@ from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_streaming_options from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api +from apache_beam.runners.internal import names from apache_beam.runners.runner import PipelineState from apache_beam.testing.extra_assertions import ExtraAssertionsMixin from apache_beam.testing.test_pipeline import TestPipeline @@ -243,6 +244,18 @@ def test_create_runner(self): self.assertTrue( isinstance(create_runner('TestDataflowRunner'), TestDataflowRunner)) + @staticmethod + def dependency_proto_from_main_session_file(serialized_path): + return [ + beam_runner_api_pb2.ArtifactInformation( + type_urn=common_urns.artifact_types.FILE.urn, + type_payload=serialized_path, + role_urn=common_urns.artifact_roles.STAGING_TO.urn, + role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload( + staged_name=names.PICKLED_MAIN_SESSION_FILE).SerializeToString( + )) + ] + def test_environment_override_translation_legacy_worker_harness_image(self): self.default_properties.append('--experiments=beam_fn_api') self.default_properties.append('--worker_harness_container_image=LEGACY') @@ -256,17 +269,22 @@ def test_environment_override_translation_legacy_worker_harness_image(self): | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) | ptransform.GroupByKey()) + actual = list(remote_runner.proto_pipeline.components.environments.values()) + self.assertEqual(len(actual), 1) + actual = actual[0] + file_path = actual.dependencies[0].type_payload + # Dependency payload contains main_session from a transient temp directory + # Use actual for expected value. + main_session_dep = self.dependency_proto_from_main_session_file(file_path) self.assertEqual( - list(remote_runner.proto_pipeline.components.environments.values()), - [ - beam_runner_api_pb2.Environment( - urn=common_urns.environments.DOCKER.urn, - payload=beam_runner_api_pb2.DockerPayload( - container_image='LEGACY').SerializeToString(), - capabilities=environments.python_sdk_docker_capabilities(), - dependencies=environments.python_sdk_dependencies( - options=options)) - ]) + actual, + beam_runner_api_pb2.Environment( + urn=common_urns.environments.DOCKER.urn, + payload=beam_runner_api_pb2.DockerPayload( + container_image='LEGACY').SerializeToString(), + capabilities=environments.python_sdk_docker_capabilities(), + dependencies=environments.python_sdk_dependencies(options=options) + + main_session_dep)) def test_environment_override_translation_sdk_container_image(self): self.default_properties.append('--experiments=beam_fn_api') @@ -281,17 +299,22 @@ def test_environment_override_translation_sdk_container_image(self): | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)]) | ptransform.GroupByKey()) + actual = list(remote_runner.proto_pipeline.components.environments.values()) + self.assertEqual(len(actual), 1) + actual = actual[0] + file_path = actual.dependencies[0].type_payload + # Dependency payload contains main_session from a transient temp directory + # Use actual for expected value. + main_session_dep = self.dependency_proto_from_main_session_file(file_path) self.assertEqual( - list(remote_runner.proto_pipeline.components.environments.values()), - [ - beam_runner_api_pb2.Environment( - urn=common_urns.environments.DOCKER.urn, - payload=beam_runner_api_pb2.DockerPayload( - container_image='FOO').SerializeToString(), - capabilities=environments.python_sdk_docker_capabilities(), - dependencies=environments.python_sdk_dependencies( - options=options)) - ]) + actual, + beam_runner_api_pb2.Environment( + urn=common_urns.environments.DOCKER.urn, + payload=beam_runner_api_pb2.DockerPayload( + container_image='FOO').SerializeToString(), + capabilities=environments.python_sdk_docker_capabilities(), + dependencies=environments.python_sdk_dependencies(options=options) + + main_session_dep)) def test_remote_runner_translation(self): remote_runner = DataflowRunner() diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index 9147410c2463..aa03082f0d57 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -376,7 +376,6 @@ def create_job_resources( pickled_session_file = os.path.join( temp_dir, names.PICKLED_MAIN_SESSION_FILE) pickler.dump_session(pickled_session_file) - # for pickle_library: cloudpickle, dump_session is no op if os.path.exists(pickled_session_file): resources.append( Stager._create_file_stage_to_artifact( diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index 22a41e592c2b..233e0c3dcea1 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -200,7 +200,7 @@ def test_with_main_session(self): # (https://github.com/apache/beam/issues/21457): Remove the decorator once # cloudpickle is default pickle library @pytest.mark.no_xdist - def test_main_session_not_staged_when_using_cloudpickle(self): + def test_main_session_staged_when_using_cloudpickle(self): staging_dir = self.make_temp_dir() options = PipelineOptions() @@ -209,7 +209,10 @@ def test_main_session_not_staged_when_using_cloudpickle(self): # session is saved when pickle_library==cloudpickle. options.view_as(SetupOptions).pickle_library = pickler.USE_CLOUDPICKLE self.update_options(options) - self.assertEqual([stager.SUBMISSION_ENV_DEPENDENCIES_FILE], + self.assertEqual([ + names.PICKLED_MAIN_SESSION_FILE, + stager.SUBMISSION_ENV_DEPENDENCIES_FILE + ], self.stager.create_and_stage_job_resources( options, staging_location=staging_dir)[1]) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index cdb807e8dbc5..e4dd6cc2121f 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -113,26 +113,25 @@ def create_harness(environment, dry_run=False): _LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory) _worker_id = environment.get('WORKER_ID', None) - if pickle_library != pickler.USE_CLOUDPICKLE: - try: - _load_main_session(semi_persistent_directory) - except LoadMainSessionException: - exception_details = traceback.format_exc() - _LOGGER.error( - 'Could not load main session: %s', exception_details, exc_info=True) - raise - except Exception: # pylint: disable=broad-except - summary = ( - "Could not load main session. Inspect which external dependencies " - "are used in the main module of your pipeline. Verify that " - "corresponding packages are installed in the pipeline runtime " - "environment and their installed versions match the versions used in " - "pipeline submission environment. For more information, see: https://" - "beam.apache.org/documentation/sdks/python-pipeline-dependencies/") - _LOGGER.error(summary, exc_info=True) - exception_details = traceback.format_exc() - deferred_exception = LoadMainSessionException( - f"{summary} {exception_details}") + try: + _load_main_session(semi_persistent_directory) + except LoadMainSessionException: + exception_details = traceback.format_exc() + _LOGGER.error( + 'Could not load main session: %s', exception_details, exc_info=True) + raise + except Exception: # pylint: disable=broad-except + summary = ( + "Could not load main session. Inspect which external dependencies " + "are used in the main module of your pipeline. Verify that " + "corresponding packages are installed in the pipeline runtime " + "environment and their installed versions match the versions used in " + "pipeline submission environment. For more information, see: https://" + "beam.apache.org/documentation/sdks/python-pipeline-dependencies/") + _LOGGER.error(summary, exc_info=True) + exception_details = traceback.format_exc() + deferred_exception = LoadMainSessionException( + f"{summary} {exception_details}") _LOGGER.info( 'Pipeline_options: %s', @@ -356,6 +355,14 @@ class LoadMainSessionException(Exception): def _load_main_session(semi_persistent_directory): """Loads a pickled main session from the path specified.""" + if pickler.is_currently_dill(): + warn_msg = ' Functions defined in __main__ (interactive session) may fail.' + err_msg = ' Functions defined in __main__ (interactive session) will ' \ + 'almost certainly fail.' + elif pickler.is_currently_cloudpickle(): + warn_msg = ' User registered objects (e.g. schema, logical type) through' \ + 'registeries may not be effective' + err_msg = '' if semi_persistent_directory: session_file = os.path.join( semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE) @@ -365,21 +372,18 @@ def _load_main_session(semi_persistent_directory): # This can happen if the worker fails to download the main session. # Raise a fatal error and crash this worker, forcing a restart. if os.path.getsize(session_file) == 0: - # Potenitally transient error, unclear if still happening. - raise LoadMainSessionException( - 'Session file found, but empty: %s. Functions defined in __main__ ' - '(interactive session) will almost certainly fail.' % - (session_file, )) - pickler.load_session(session_file) + if pickler.is_currently_dill(): + # Potenitally transient error, unclear if still happening. + raise LoadMainSessionException( + 'Session file found, but empty: %s.%s' % (session_file, err_msg)) + else: + _LOGGER.warning('Empty session file: %s.%s', warn_msg, session_file) + else: + pickler.load_session(session_file) else: - _LOGGER.warning( - 'No session file found: %s. Functions defined in __main__ ' - '(interactive session) may fail.', - session_file) + _LOGGER.warning('No session file found: %s.%s', warn_msg, session_file) else: - _LOGGER.warning( - 'No semi_persistent_directory found: Functions defined in __main__ ' - '(interactive session) may fail.') + _LOGGER.warning('No semi_persistent_directory found: %s', warn_msg) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index c21dde426fc7..e9674fa5bc20 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -684,12 +684,17 @@ def __init__(self): self.by_urn = {} self.by_logical_type = {} self.by_language_type = {} + self._custom_urns = set() - def add(self, urn, logical_type): + def _add_internal(self, urn, logical_type): self.by_urn[urn] = logical_type self.by_logical_type[logical_type] = urn self.by_language_type[logical_type.language_type()] = logical_type + def add(self, urn, logical_type): + self._add_internal(urn, logical_type) + self._custom_urns.add(urn) + def get_logical_type_by_urn(self, urn): return self.by_urn.get(urn, None) @@ -704,8 +709,25 @@ def copy(self): copy.by_urn.update(self.by_urn) copy.by_logical_type.update(self.by_logical_type) copy.by_language_type.update(self.by_language_type) + copy._custom_urns.update(self._custom_urns) return copy + def copy_custom(self): + copy = LogicalTypeRegistry() + for urn in self._custom_urns: + logical_type = self.by_urn[urn] + copy.by_urn[urn] = logical_type + copy.by_logical_type[logical_type] = urn + copy.by_language_type[logical_type.language_type()] = logical_type + copy._custom_urns.add(urn) + return copy + + def load(self, another): + self.by_urn.update(another.by_urn) + self.by_logical_type.update(another.by_logical_type) + self.by_language_type.update(another.by_language_type) + self._custom_urns.update(another._custom_urns) + LanguageT = TypeVar('LanguageT') RepresentationT = TypeVar('RepresentationT') @@ -768,6 +790,19 @@ def to_language_type(self, value): """Convert an instance of RepresentationT to LanguageT.""" raise NotImplementedError() + @classmethod + def _register_internal(cls, logical_type_cls): + """ + Register an implementation of LogicalType. + + The types registered using this decorator are not pickled on pipeline + submission, as it relies module import to be registered on worker + initialization. Should be used within schemas module and static context. + """ + cls._known_logical_types._add_internal( + logical_type_cls.urn(), logical_type_cls) + return logical_type_cls + @classmethod def register_logical_type(cls, logical_type_cls): """Register an implementation of LogicalType.""" @@ -884,7 +919,7 @@ def _from_typing(cls, typ): ('micros', np.int64)]) -@LogicalType.register_logical_type +@LogicalType._register_internal class MillisInstant(NoArgumentLogicalType[Timestamp, np.int64]): """Millisecond-precision instant logical type handles values consistent with that encoded by ``InstantCoder`` in the Java SDK. @@ -928,7 +963,7 @@ def to_language_type(self, value): # Make sure MicrosInstant is registered after MillisInstant so that it # overwrites the mapping of Timestamp language type representation choice and # thus does not lose microsecond precision inside python sdk. -@LogicalType.register_logical_type +@LogicalType._register_internal class MicrosInstant(NoArgumentLogicalType[Timestamp, MicrosInstantRepresentation]): """Microsecond-precision instant logical type that handles ``Timestamp``.""" @@ -955,7 +990,7 @@ def to_language_type(self, value): return Timestamp(seconds=int(value.seconds), micros=int(value.micros)) -@LogicalType.register_logical_type +@LogicalType._register_internal class PythonCallable(NoArgumentLogicalType[PythonCallableWithSource, str]): """A logical type for PythonCallableSource objects.""" @classmethod @@ -1011,7 +1046,7 @@ def to_language_type(self, value): return decimal.Decimal(value.decode()) -@LogicalType.register_logical_type +@LogicalType._register_internal class FixedPrecisionDecimalLogicalType( LogicalType[decimal.Decimal, DecimalLogicalType, @@ -1063,10 +1098,10 @@ def _from_typing(cls, typ): # TODO(yathu,BEAM-10722): Investigate and resolve conflicts in logical type # registration when more than one logical types sharing the same language type -LogicalType.register_logical_type(DecimalLogicalType) +LogicalType._register_internal(DecimalLogicalType) -@LogicalType.register_logical_type +@LogicalType._register_internal class FixedBytes(PassThroughLogicalType[bytes, np.int32]): """A logical type for fixed-length bytes.""" @classmethod @@ -1099,7 +1134,7 @@ def argument(self): return self.length -@LogicalType.register_logical_type +@LogicalType._register_internal class VariableBytes(PassThroughLogicalType[bytes, np.int32]): """A logical type for variable-length bytes with specified maximum length.""" @classmethod @@ -1129,7 +1164,7 @@ def argument(self): return self.max_length -@LogicalType.register_logical_type +@LogicalType._register_internal class FixedString(PassThroughLogicalType[str, np.int32]): """A logical type for fixed-length string.""" @classmethod @@ -1162,7 +1197,7 @@ def argument(self): return self.length -@LogicalType.register_logical_type +@LogicalType._register_internal class VariableString(PassThroughLogicalType[str, np.int32]): """A logical type for variable-length string with specified maximum length.""" @classmethod @@ -1195,7 +1230,7 @@ def argument(self): # TODO: A temporary fix for missing jdbc logical types. # See the discussion in https://github.com/apache/beam/issues/35738 for # more detail. -@LogicalType.register_logical_type +@LogicalType._register_internal class JdbcDateType(LogicalType[datetime.date, MillisInstant, str]): """ For internal use only; no backwards-compatibility guarantees. @@ -1238,7 +1273,7 @@ def _from_typing(cls, typ): return cls() -@LogicalType.register_logical_type +@LogicalType._register_internal class JdbcTimeType(LogicalType[datetime.time, MillisInstant, str]): """ For internal use only; no backwards-compatibility guarantees.