Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import decimal
import enum
import functools
import itertools
import json
import logging
Expand Down Expand Up @@ -354,6 +355,7 @@ def decode(self, value):
NAMED_TUPLE_TYPE = 102
ENUM_TYPE = 103
NESTED_STATE_TYPE = 104
DATACLASS_KW_ONLY_TYPE = 105

# Types that can be encoded as iterables, but are not literally
# lists, etc. due to being lazy. The actual type is not preserved
Expand All @@ -374,6 +376,18 @@ def _verify_dill_compat():
raise RuntimeError(base_error + f". Found dill version '{dill.__version__}")


dataclass_uses_kw_only: Callable[[Any], bool]
if dataclasses:
# Cache the result to avoid multiple checks for the same dataclass type.
@functools.cache
def dataclass_uses_kw_only(cls) -> bool:
return any(
field.init and field.kw_only for field in dataclasses.fields(cls))

else:
dataclass_uses_kw_only = lambda cls: False


class FastPrimitivesCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(
Expand Down Expand Up @@ -497,18 +511,25 @@ def encode_special_deterministic(self, value, stream):
self.encode_type(type(value), stream)
stream.write(value.SerializePartialToString(deterministic=True), True)
elif dataclasses and dataclasses.is_dataclass(value):
stream.write_byte(DATACLASS_TYPE)
if not type(value).__dataclass_params__.frozen:
raise TypeError(
"Unable to deterministically encode non-frozen '%s' of type '%s' "
"for the input of '%s'" %
(value, type(value), self.requires_deterministic_step_label))
self.encode_type(type(value), stream)
values = [
getattr(value, field.name) for field in dataclasses.fields(value)
]
init_fields = [field for field in dataclasses.fields(value) if field.init]
try:
self.iterable_coder_impl.encode_to_stream(values, stream, True)
if dataclass_uses_kw_only(type(value)):
stream.write_byte(DATACLASS_KW_ONLY_TYPE)
self.encode_type(type(value), stream)
stream.write_var_int64(len(init_fields))
for field in init_fields:
stream.write(field.name.encode("utf-8"), True)
self.encode_to_stream(getattr(value, field.name), stream, True)
else: # Not using kw_only, we can pass parameters by position.
stream.write_byte(DATACLASS_TYPE)
self.encode_type(type(value), stream)
values = [getattr(value, field.name) for field in init_fields]
self.iterable_coder_impl.encode_to_stream(values, stream, True)
except Exception as e:
raise TypeError(self._deterministic_encoding_error_msg(value)) from e
elif isinstance(value, tuple) and hasattr(type(value), '_fields'):
Expand Down Expand Up @@ -616,6 +637,14 @@ def decode_from_stream(self, stream, nested):
msg = cls()
msg.ParseFromString(stream.read_all(True))
return msg
elif t == DATACLASS_KW_ONLY_TYPE:
cls = self.decode_type(stream)
vlen = stream.read_var_int64()
fields = {}
for _ in range(vlen):
field_name = stream.read_all(True).decode('utf-8')
fields[field_name] = self.decode_from_stream(stream, True)
return cls(**fields)
elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE:
cls = self.decode_type(stream)
return cls(*self.iterable_coder_impl.decode_from_stream(stream, True))
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class FrozenDataClass:
a: Any
b: int

@dataclasses.dataclass(frozen=True, kw_only=True)
class FrozenKwOnlyDataClass:
c: int
d: int

@dataclasses.dataclass
class UnFrozenDataClass:
x: int
Expand Down Expand Up @@ -303,9 +308,11 @@ def test_deterministic_coder(self, compat_version):

if dataclasses is not None:
self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2))

with self.assertRaises(TypeError):
self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))

with self.assertRaises(TypeError):
self.check_coder(
deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
Expand Down Expand Up @@ -742,6 +749,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
from apache_beam.coders.coders_test_common import DefinesGetState
from apache_beam.coders.coders_test_common import DefinesGetAndSetState
from apache_beam.coders.coders_test_common import FrozenDataClass
from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass


from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
Expand Down Expand Up @@ -777,6 +785,8 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
test_cases.extend([
("frozen_dataclass", FrozenDataClass(1, 2)),
("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)),
("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]),
])

compat_version = {'"'+ compat_version +'"' if compat_version else None}
Expand Down
Loading