diff --git a/paimon-python/pypaimon/table/row/generic_row.py b/paimon-python/pypaimon/table/row/generic_row.py index 4aa740de7219..1460f649d1c2 100644 --- a/paimon-python/pypaimon/table/row/generic_row.py +++ b/paimon-python/pypaimon/table/row/generic_row.py @@ -16,6 +16,8 @@ # limitations under the License. ################################################################################ +import calendar +import decimal import struct from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -29,6 +31,63 @@ from pypaimon.table.row.blob import BlobData +_DECIMAL_CTX = decimal.Context(prec=100, rounding=decimal.ROUND_HALF_UP) + + +def _decimal_to_unscaled_with_check(d: Decimal, precision: int, scale: int): + """Round decimal with HALF_UP, check precision overflow, and return unscaled value. + Returns (unscaled_int, True) on overflow, (unscaled_int, False) on success.""" + rounded = d.quantize(Decimal(10) ** -scale, context=_DECIMAL_CTX) + sign, digits, exponent = rounded.as_tuple() + # Precision overflow check + if rounded != 0 and len(digits) > precision: + return 0, True + int_digits = int(''.join(str(x) for x in digits)) if digits != (0,) else 0 + shift = exponent + scale + if shift >= 0: + unscaled = int_digits * (10 ** shift) + else: + unscaled = int_digits // (10 ** (-shift)) + return (-unscaled if sign else unscaled), False + + +def _parse_type_precision_scale(data_type): + """Parse precision and scale from type string like DECIMAL(38, 10).""" + type_str = str(data_type) + if '(' in type_str and ')' in type_str: + try: + params_str = type_str.split('(')[1].split(')')[0] + parts = [p.strip() for p in params_str.split(',')] + precision = int(parts[0]) + scale = int(parts[1]) if len(parts) > 1 else 0 + return precision, scale + except (ValueError, IndexError): + return 0, 0 + return 0, 0 + + +_EPOCH = datetime(1970, 1, 1) + + +def _datetime_to_millis_and_nanos(value: datetime): + """Convert datetime to (epoch_millis, nano_of_millisecond) without float arithmetic.""" + epoch_seconds = calendar.timegm(value.timetuple()) + millis = epoch_seconds * 1000 + value.microsecond // 1000 + nano_of_millisecond = (value.microsecond % 1000) * 1000 + return millis, nano_of_millisecond + + +def _millis_nanos_to_datetime(millis: int, nano_of_millisecond: int = 0) -> datetime: + """Convert (epoch_millis, nano_of_millisecond) to datetime. Nanos truncated to micros.""" + total_micros = millis * 1000 + nano_of_millisecond // 1000 + seconds = total_micros // 1_000_000 + micros = total_micros % 1_000_000 + if micros < 0: + seconds -= 1 + micros += 1_000_000 + return _EPOCH + timedelta(seconds=seconds, microseconds=micros) + + @dataclass class GenericRow(InternalRow): @@ -233,26 +292,49 @@ def _parse_blob(cls, bytes_data: bytes, base_offset: int, field_offset: int) -> return BlobData.from_bytes(binary_data) @classmethod - def _parse_decimal(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType) -> Decimal: - unscaled_long = struct.unpack(' Decimal: + sign = 0 if unscaled_value >= 0 else 1 + digits = tuple(int(d) for d in str(abs(unscaled_value))) if unscaled_value != 0 else (0,) + return Decimal((sign, digits, -scale)) + + @classmethod + def _parse_decimal(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType): + precision, scale = _parse_type_precision_scale(data_type) + if precision <= 0: + raise ValueError(f"Decimal requires precision > 0, got {precision}") + if precision <= 18: + # Compact: unscaled long in fixed part + unscaled_long = struct.unpack('> 32) & 0xFFFFFFFF + byte_length = offset_and_len & 0xFFFFFFFF + var_offset = base_offset + cursor + unscaled_bytes = bytes_data[var_offset:var_offset + byte_length] + unscaled_value = int.from_bytes(unscaled_bytes, byteorder='big', signed=True) + # Precision overflow returns null + result = cls._unscaled_to_decimal(unscaled_value, scale) + _, digits, _ = result.as_tuple() + if result != 0 and len(digits) > precision: + return None + return result @classmethod def _parse_timestamp(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType) -> datetime: - millis = struct.unpack('> 32) & 0xFFFFFFFF + millis = struct.unpack(' date: @@ -301,9 +383,45 @@ def to_bytes(cls, row: Union[GenericRow, BinaryRow]) -> bytes: raise ValueError(f"BinaryRow only support AtomicType yet, meet {field.type.__class__}") type_name = field.type.type.upper() - if any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING', - 'BINARY', 'VARBINARY', 'BYTES', 'BLOB']): - if any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING']): + is_var_len_type = any(type_name.startswith(p) for p in [ + 'CHAR', 'VARCHAR', 'STRING', 'BINARY', 'VARBINARY', 'BYTES', 'BLOB']) + is_decimal_type = type_name.startswith('DECIMAL') or type_name.startswith('NUMERIC') + is_timestamp_type = type_name.startswith('TIMESTAMP') + decimal_precision, decimal_scale = _parse_type_precision_scale(field.type) if is_decimal_type else (0, 0) + is_high_precision_decimal = is_decimal_type and decimal_precision > 18 + timestamp_precision = _parse_type_precision_scale(field.type)[0] if is_timestamp_type else 0 + is_non_compact_timestamp = is_timestamp_type and timestamp_precision > 3 + + # Precision overflow -> null + if is_decimal_type and value is not None: + d = value if isinstance(value, Decimal) else Decimal(str(value)) + unscaled_value, overflow = _decimal_to_unscaled_with_check(d, decimal_precision, decimal_scale) + if overflow: + cls._set_null_bit(fixed_part, 0, i) + struct.pack_into(' bytes: value_bytes = bytes(value) length = len(value_bytes) - if length <= cls.MAX_FIX_PART_DATA_SIZE: + if length <= cls.MAX_FIX_PART_DATA_SIZE and not is_high_precision_decimal: fixed_part[field_fixed_offset: field_fixed_offset + length] = value_bytes for j in range(length, 7): fixed_part[field_fixed_offset + j] = 0 header_byte = 0x80 | length fixed_part[field_fixed_offset + 7] = header_byte else: - var_length = cls._round_number_of_bytes_to_nearest_word(len(value_bytes)) + # Non-compact decimal: fixed 16 bytes; others: 8-byte aligned + if is_high_precision_decimal: + var_length = 16 + else: + var_length = cls._round_number_of_bytes_to_nearest_word(len(value_bytes)) var_value_bytes = value_bytes + b'\x00' * (var_length - length) offset_in_variable_part = current_variable_offset variable_part_data.append(var_value_bytes) @@ -365,8 +487,18 @@ def _serialize_field_value(cls, value: Any, data_type: AtomicType) -> bytes: elif type_name in ['DOUBLE']: return cls._serialize_double(value) elif type_name.startswith('DECIMAL') or type_name.startswith('NUMERIC'): + precision, _ = _parse_type_precision_scale(data_type) + if precision > 18: + raise ValueError( + f"Non-compact decimal (precision={precision}) must be serialized " + f"via the variable-length path in to_bytes(), not _serialize_field_value()") return cls._serialize_decimal(value, data_type) elif type_name.startswith('TIMESTAMP'): + precision = _parse_type_precision_scale(data_type)[0] + if precision > 3: + raise ValueError( + f"Non-compact timestamp (precision={precision}) must be serialized " + f"via the variable-length path in to_bytes(), not _serialize_field_value()") return cls._serialize_timestamp(value) elif type_name in ['DATE']: return cls._serialize_date(value) + b'\x00' * 4 @@ -405,27 +537,17 @@ def _serialize_double(cls, value: float) -> bytes: @classmethod def _serialize_decimal(cls, value: Decimal, data_type: DataType) -> bytes: - type_str = str(data_type) - if '(' in type_str and ')' in type_str: - try: - precision_scale = type_str.split('(')[1].split(')')[0] - if ',' in precision_scale: - scale = int(precision_scale.split(',')[1]) - else: - scale = 0 - except: - scale = 0 - else: - scale = 0 - - unscaled_value = int(value * (10 ** scale)) + """Compact decimal: unscaled long in fixed part.""" + precision, scale = _parse_type_precision_scale(data_type) + d = value if isinstance(value, Decimal) else Decimal(str(value)) + unscaled_value, _ = _decimal_to_unscaled_with_check(d, precision, scale) return struct.pack(' bytes: if value.tzinfo is not None: raise RuntimeError("datetime tzinfo not supported yet") - millis = int(value.timestamp() * 1000) + millis, _ = _datetime_to_millis_and_nanos(value) return struct.pack(' 0.05 + fields = [ + DataField(0, "d", AtomicType("DECIMAL(4, 2)")), + DataField(1, "d2", AtomicType("DECIMAL(4, 2)")), + ] + row = GenericRow([Decimal("0.05"), None], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + + self.assertEqual(str(result.values[0]), "0.05") + self.assertIsNone(result.values[1]) + + # Another compact value: 0.06 + row2 = GenericRow([Decimal("0.06"), None], fields, RowKind.INSERT) + serialized2 = GenericRowSerializer.to_bytes(row2) + result2 = GenericRowDeserializer.from_bytes(serialized2, fields) + self.assertEqual(str(result2.values[0]), "0.06") + + def test_decimal_not_compact(self): + """Test non-compact decimal (precision > 18) round-trip.""" + # precision=25, scale=5 + fields = [ + DataField(0, "d", AtomicType("DECIMAL(25, 5)")), + DataField(1, "d2", AtomicType("DECIMAL(25, 5)")), + ] + row = GenericRow([Decimal("5.55000"), None], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + + self.assertEqual(str(result.values[0]), "5.55000") + self.assertIsNone(result.values[1]) + + # Another value: 6.55 + row2 = GenericRow([Decimal("6.55000"), None], fields, RowKind.INSERT) + serialized2 = GenericRowSerializer.to_bytes(row2) + result2 = GenericRowDeserializer.from_bytes(serialized2, fields) + self.assertEqual(str(result2.values[0]), "6.55000") + + # Negative value + row3 = GenericRow([Decimal("-123.45000"), None], fields, RowKind.INSERT) + serialized3 = GenericRowSerializer.to_bytes(row3) + result3 = GenericRowDeserializer.from_bytes(serialized3, fields) + self.assertEqual(str(result3.values[0]), "-123.45000") + + def test_decimal_high_precision_large_value(self): + """Test high-precision decimal with large values that exceed long range.""" + fields = [DataField(0, "d", AtomicType("DECIMAL(38, 10)"))] + + test_values = [ + Decimal("12345678901234567890.1234567890"), + Decimal("-99999999999999999999.9999999999"), + Decimal("0E-10"), + ] + + for val in test_values: + with self.subTest(value=val): + row = GenericRow([val], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + self.assertEqual(result.values[0], val) + + def test_decimal_mixed_with_other_types(self): + """Test decimal fields mixed with other types in a single row.""" + fields = [ + DataField(0, "id", AtomicType("INT")), + DataField(1, "name", AtomicType("STRING")), + DataField(2, "compact_dec", AtomicType("DECIMAL(10, 2)")), + DataField(3, "high_dec", AtomicType("DECIMAL(38, 2)")), + DataField(4, "score", AtomicType("DOUBLE")), + ] + + row = GenericRow( + [42, "test_row", Decimal("12345.67"), Decimal("12312455.22"), 3.14], + fields, RowKind.INSERT + ) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + + self.assertEqual(result.values[0], 42) + self.assertEqual(result.values[1], "test_row") + self.assertEqual(result.values[2], Decimal("12345.67")) + self.assertEqual(result.values[3], Decimal("12312455.22")) + self.assertAlmostEqual(result.values[4], 3.14) + + def test_decimal_compact_binary_format(self): + """Verify compact decimal binary layout: unscaled long in fixed part.""" + fields = [DataField(0, "d", AtomicType("DECIMAL(4, 2)"))] + row = GenericRow([Decimal("0.05")], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + + # Skip 4-byte arity prefix + data = serialized[4:] + null_bits_size = 8 # ((1 + 63 + 8) // 64) * 8 + field_offset = null_bits_size + unscaled_long = struct.unpack(' unscaled = 5 + self.assertEqual(unscaled_long, 5) + + def test_decimal_not_compact_binary_format(self): + """Verify non-compact decimal binary layout: (offset << 32 | length) in fixed part, + 16-byte big-endian unscaled bytes in variable part. + """ + fields = [DataField(0, "d", AtomicType("DECIMAL(25, 5)"))] + row = GenericRow([Decimal("5.55000")], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + + # Skip 4-byte arity prefix + data = serialized[4:] + null_bits_size = 8 + field_offset = null_bits_size + fixed_part_size = null_bits_size + 1 * 8 + + offset_and_len = struct.unpack('> 32) & 0xFFFFFFFF + byte_length = offset_and_len & 0xFFFFFFFF + + # cursor should point to the variable area (== fixed_part_size) + self.assertEqual(cursor, fixed_part_size) + # variable area should be exactly 16 bytes + var_area = data[cursor:] + self.assertEqual(len(var_area), 16) + # unscaled bytes are big-endian signed + unscaled_bytes = data[cursor:cursor + byte_length] + unscaled_value = int.from_bytes(unscaled_bytes, byteorder='big', signed=True) + # Decimal("5.55000") with scale=5 => unscaled = 555000 + self.assertEqual(unscaled_value, 555000) + + def test_decimal_boundary_precision(self): + """Test boundary: DECIMAL(18, ...) is compact, DECIMAL(19, ...) is non-compact.""" + # precision=18: last compact + fields_18 = [DataField(0, "d", AtomicType("DECIMAL(18, 4)"))] + row_18 = GenericRow([Decimal("12345678901234.5678")], fields_18, RowKind.INSERT) + s_18 = GenericRowSerializer.to_bytes(row_18) + r_18 = GenericRowDeserializer.from_bytes(s_18, fields_18) + self.assertEqual(r_18.values[0], Decimal("12345678901234.5678")) + # verify compact: no variable area beyond fixed part + data_18 = s_18[4:] + null_bits_size = 8 + fixed_part_size = null_bits_size + 1 * 8 + self.assertEqual(len(data_18), fixed_part_size) + + # precision=19: first non-compact + fields_19 = [DataField(0, "d", AtomicType("DECIMAL(19, 4)"))] + row_19 = GenericRow([Decimal("12345678901234.5678")], fields_19, RowKind.INSERT) + s_19 = GenericRowSerializer.to_bytes(row_19) + r_19 = GenericRowDeserializer.from_bytes(s_19, fields_19) + self.assertEqual(r_19.values[0], Decimal("12345678901234.5678")) + # verify non-compact: has 16-byte variable area + data_19 = s_19[4:] + self.assertEqual(len(data_19), fixed_part_size + 16) + + def test_decimal_zero_different_scales(self): + """Test zero value with different precisions and scales.""" + test_cases = [ + ("DECIMAL(38, 0)", Decimal("0")), + ("DECIMAL(38, 10)", Decimal("0E-10")), + ("DECIMAL(10, 2)", Decimal("0.00")), + ] + for type_str, val in test_cases: + with self.subTest(type=type_str): + fields = [DataField(0, "d", AtomicType(type_str))] + row = GenericRow([val], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + self.assertEqual(result.values[0], val) + + def test_decimal_half_up_rounding(self): + """Excess fractional digits should be rounded with HALF_UP.""" + fields = [DataField(0, "d", AtomicType("DECIMAL(10, 2)"))] + + test_cases = [ + (Decimal("1.999"), Decimal("2.00")), # .999 rounds up + (Decimal("1.235"), Decimal("1.24")), # .235 rounds up (HALF_UP) + (Decimal("1.234"), Decimal("1.23")), # .234 rounds down + (Decimal("1.225"), Decimal("1.23")), # .225 rounds up (HALF_UP) + (Decimal("-1.235"), Decimal("-1.24")), # negative HALF_UP + ] + for val, expected in test_cases: + with self.subTest(value=val): + row = GenericRow([val], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + self.assertEqual(result.values[0], expected) + + def test_decimal_precision_overflow_returns_null(self): + """Values exceeding declared precision should be stored as null.""" + # DECIMAL(4, 2) can hold at most 2 integer + 2 fractional digits => max 99.99 + fields = [DataField(0, "d", AtomicType("DECIMAL(4, 2)"))] + + # 999.99 needs 5 digits total, exceeds precision=4 + row = GenericRow([Decimal("999.99")], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + self.assertIsNone(result.values[0]) + + # 99.999 rounds to 100.00 (5 digits), also overflows + row2 = GenericRow([Decimal("99.999")], fields, RowKind.INSERT) + serialized2 = GenericRowSerializer.to_bytes(row2) + result2 = GenericRowDeserializer.from_bytes(serialized2, fields) + self.assertIsNone(result2.values[0]) + + # 99.99 fits exactly in DECIMAL(4, 2) + row3 = GenericRow([Decimal("99.99")], fields, RowKind.INSERT) + serialized3 = GenericRowSerializer.to_bytes(row3) + result3 = GenericRowDeserializer.from_bytes(serialized3, fields) + self.assertEqual(result3.values[0], Decimal("99.99")) + + def test_decimal_precision_overflow_high_precision(self): + """Precision overflow check also works for non-compact decimals.""" + # DECIMAL(20, 5) can hold 15 integer + 5 fractional digits + fields = [DataField(0, "d", AtomicType("DECIMAL(20, 5)"))] + + # This value fits: 15 integer digits + 5 fractional + row = GenericRow([Decimal("123456789012345.12345")], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + self.assertEqual(result.values[0], Decimal("123456789012345.12345")) + + # This value overflows: 16 integer digits + 5 fractional = 21 > 20 + row2 = GenericRow([Decimal("1234567890123456.12345")], fields, RowKind.INSERT) + serialized2 = GenericRowSerializer.to_bytes(row2) + result2 = GenericRowDeserializer.from_bytes(serialized2, fields) + self.assertIsNone(result2.values[0]) + + def test_decimal_deserialization_precision_overflow_non_compact(self): + """Non-compact decimal deserialization returns None if precision overflows.""" + # Serialize with DECIMAL(38, 5) which fits, then deserialize as DECIMAL(20, 5) + fields_wide = [DataField(0, "d", AtomicType("DECIMAL(38, 5)"))] + fields_narrow = [DataField(0, "d", AtomicType("DECIMAL(20, 5)"))] + + # 21 digits total exceeds precision=20 + row = GenericRow([Decimal("1234567890123456.12345")], fields_wide, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields_narrow) + self.assertIsNone(result.values[0]) + + def test_decimal_deserialization_invalid_precision(self): + """Deserialization with precision <= 0 raises ValueError.""" + fields_valid = [DataField(0, "d", AtomicType("DECIMAL(10, 2)"))] + row = GenericRow([Decimal("1.23")], fields_valid, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + + fields_bad = [DataField(0, "d", AtomicType("DECIMAL(0, 2)"))] + with self.assertRaises(ValueError): + GenericRowDeserializer.from_bytes(serialized, fields_bad) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/timestamp_test.py b/paimon-python/pypaimon/tests/timestamp_test.py new file mode 100644 index 000000000000..360f4fa099a5 --- /dev/null +++ b/paimon-python/pypaimon/tests/timestamp_test.py @@ -0,0 +1,200 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import struct +import unittest +from datetime import datetime +from decimal import Decimal + +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer, GenericRowDeserializer +from pypaimon.table.row.row_kind import RowKind + + +class TimestampTest(unittest.TestCase): + """Tests for timestamp serialization/deserialization in GenericRow, + aligned with Java BinaryRow's compact/non-compact format.""" + + def test_timestamp_compact(self): + """Compact timestamp (precision <= 3): epoch millis stored directly in fixed part.""" + for type_str in ["TIMESTAMP(0)", "TIMESTAMP(3)"]: + with self.subTest(type=type_str): + fields = [DataField(0, "ts", AtomicType(type_str))] + ts = datetime(2025, 4, 8, 10, 30, 0, 123000) # .123 seconds + row = GenericRow([ts], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + result = GenericRowDeserializer.from_bytes(serialized, fields) + # compact stores millis, so microsecond part is truncated to millis + expected_millis = int(ts.timestamp() * 1000) + actual_millis = int(result.values[0].timestamp() * 1000) + self.assertEqual(actual_millis, expected_millis) + + def test_timestamp_compact_binary_format(self): + """Verify compact timestamp binary layout: epoch millis in fixed slot, no variable area.""" + fields = [DataField(0, "ts", AtomicType("TIMESTAMP(3)"))] + ts = datetime(2025, 4, 8, 10, 30, 0) + row = GenericRow([ts], fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + + data = serialized[4:] # skip arity prefix + null_bits_size = 8 + fixed_part_size = null_bits_size + 1 * 8 + # No variable area for compact timestamp + self.assertEqual(len(data), fixed_part_size) + # Fixed slot contains epoch millis + field_offset = null_bits_size + millis = struct.unpack('> 32) & 0xFFFFFFFF + nano_of_millisecond = offset_and_nano & 0xFFFFFFFF + + # cursor should point to variable area + self.assertEqual(cursor, fixed_part_size) + # 123456 us = 123 ms + 456 us = 123 ms + 456000 ns + self.assertEqual(nano_of_millisecond, 456000) + + # Variable area contains epoch millis + var_millis = struct.unpack('