diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 1e3bb2ece92a..3e0b5218b166 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -32,7 +32,6 @@ import decimal import enum -import functools import itertools import json import logging @@ -376,18 +375,6 @@ def _verify_dill_compat(): raise RuntimeError(base_error + f". Found dill version '{dill.__version__}") -dataclass_uses_kw_only: Callable[[Any], bool] -if dataclasses: - # Cache the result to avoid multiple checks for the same dataclass type. - @functools.cache - def dataclass_uses_kw_only(cls) -> bool: - return any( - field.init and field.kw_only for field in dataclasses.fields(cls)) - -else: - dataclass_uses_kw_only = lambda cls: False - - class FastPrimitivesCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def __init__( @@ -518,7 +505,7 @@ def encode_special_deterministic(self, value, stream): (value, type(value), self.requires_deterministic_step_label)) init_fields = [field for field in dataclasses.fields(value) if field.init] try: - if dataclass_uses_kw_only(type(value)): + if any(field.kw_only for field in init_fields): stream.write_byte(DATACLASS_KW_ONLY_TYPE) self.encode_type(type(value), stream) stream.write_var_int64(len(init_fields)) diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 8f89ab9602c1..fcc5e6ac58bf 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -123,6 +123,15 @@ class UnFrozenDataClass: x: int y: int + @dataclasses.dataclass(frozen=True, kw_only=True) + class FrozenUnInitKwOnlyDataClass: + side: int + area: int = dataclasses.field(init=False) + + def __post_init__(self): + # Hack to update an attribute in a frozen dataclass. + object.__setattr__(self, 'area', self.side**2) + # These tests need to all be run in the same process due to the asserts # in tearDownClass. @@ -309,6 +318,8 @@ def test_deterministic_coder(self, compat_version): if dataclasses is not None: self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2)) + self.check_coder( + deterministic_coder, FrozenUnInitKwOnlyDataClass(side=11)) with self.assertRaises(TypeError): self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2)) @@ -750,6 +761,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( from apache_beam.coders.coders_test_common import DefinesGetAndSetState from apache_beam.coders.coders_test_common import FrozenDataClass from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass + from apache_beam.coders.coders_test_common import FrozenUnInitKwOnlyDataClass from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message @@ -786,7 +798,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( ("frozen_dataclass", FrozenDataClass(1, 2)), ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]), ("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)), - ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]), + ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenUnInitKwOnlyDataClass(side=3)]), ]) compat_version = {'"'+ compat_version +'"' if compat_version else None}