Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
584 changes: 569 additions & 15 deletions c/driver/postgresql/copy/postgres_copy_writer_test.cc
Comment thread
Mandukhai-Alimaa marked this conversation as resolved.

Large diffs are not rendered by default.

282 changes: 226 additions & 56 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <algorithm>
#include <charconv>
#include <cinttypes>
#include <limits>
Expand Down Expand Up @@ -224,82 +225,141 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale)
: precision_{precision}, scale_{scale} {}

// PostgreSQL NUMERIC Binary Format:
// ===================================
// PostgreSQL stores NUMERIC values in a variable-length binary format:
// - ndigits (int16): Number of base-10000 digits stored
// - weight (int16): Position of the first digit group relative to decimal point
// (weight can be negative for small fractional numbers)
// - sign (int16): kNumericPos (0x0000) or kNumericNeg (0x4000)
// - dscale (int16): Number of decimal digits after the decimal point (display scale)
// - digits[]: Array of int16 values, each 0-9999 (base-10000 representation)
//
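// Value calculation: sum(digits[i] * 10000^(weight - i))
// (dscale does not rescale the value; it only says how many decimal digits to
// display after the decimal point)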
//
// Example 1: 12300 (from Arrow Decimal value=123, scale=-2)
// - Logical representation: "12300"
// - Grouped in base-10000: [1][2300]
// - ndigits=2, weight=1, sign=0x0000, dscale=0, digits=[1, 2300]
// - Calculation: 1*10000^1 + 2300*10000^0 = 10000 + 2300 = 12300
//
// Example 2: 123.45 (from Arrow Decimal value=12345, scale=2)
// - Logical representation: "123.45"
// - Integer part "123", fractional part "45"
// - Grouped in base-10000: [123][4500] (fractional part right-padded)
// - ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
// - Calculation: 123*10000^0 + 4500*10000^(-1) = 123 + 0.45 = 123.45
//
// Example 3: 0.00123 (from Arrow Decimal value=123, scale=5)
// - Logical representation: "0.00123"
// - Integer part empty (logically "0"), fractional part "00123"
// - Grouped in base-10000: [12][3000] ("0012" -> 12, "3" right-padded to "3000")
// - ndigits=2, weight=-1, sign=0x0000, dscale=5, digits=[12, 3000]
// - Calculation: 12*10000^(-1) + 3000*10000^(-2) = 0.0012 + 0.00003 = 0.00123
// - A leading fractional group that is entirely zero (e.g. 0.00001234) is not
//   stored; it is skipped by lowering the weight instead (weight=-2, digits=[1234])

// Appends the decimal at `index` of the bound array view to `buffer` as one
// PostgreSQL binary-COPY NUMERIC field:
//   int32 field length, int16 ndigits, int16 weight, int16 sign, int16 dscale,
//   followed by ndigits base-10000 digit groups (int16 each).
//
// Worked example for Arrow Decimal(value=12345, scale=2), i.e. 123.45:
//   DecimalToString      -> "12345" (5 digits)
//   SplitDecimalParts    -> integer_part="123", fractional_part="45", effective_scale=2
//   GroupIntegerDigits   -> [123], weight=0
//   GroupFractionalDigits-> [4500], weight=0
//   written              -> ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
//
// Returns the first error reported by WriteChecked/ArrowBufferReserve, otherwise
// ADBC_STATUS_OK.
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
  struct ArrowDecimal decimal;
  ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
  ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);

  const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;

  // Render the magnitude as decimal text and split it around the decimal point.
  char raw_decimal_string[max_decimal_digits_ + 1];
  const int original_digits = DecimalToString<bitwidth_>(&decimal, raw_decimal_string);
  const DecimalParts parts =
      SplitDecimalParts(raw_decimal_string, original_digits, scale_);
  const bool has_integer_part = !parts.integer_part.empty();

  // Group both halves into base-10000 digits.
  auto [int_digits, weight] = GroupIntegerDigits(parts.integer_part);
  auto [frac_digits, final_weight] =
      GroupFractionalDigits(parts.fractional_part, weight, has_integer_part);

  // Trailing all-zero fractional groups carry no information; do not store them.
  while (!frac_digits.empty() && frac_digits.back() == 0) {
    frac_digits.pop_back();
  }

  std::vector<int16_t> all_digits;
  all_digits.reserve(static_cast<size_t>((std::max)(0, static_cast<int>(weight) + 1)) +
                     frac_digits.size());
  all_digits.insert(all_digits.end(), int_digits.begin(), int_digits.end());

  if (has_integer_part && !frac_digits.empty()) {
    // GroupIntegerDigits drops trailing all-zero integer groups ("10000" -> [1],
    // weight=1). That is fine for a pure integer, but digit i is worth
    // 10000^(weight - i), so the first fractional group must sit at position
    // weight + 1. Re-insert the dropped zero groups, otherwise 10000.45 would be
    // encoded as [1, 4500] with weight=1, which PostgreSQL reads as 14500.
    const size_t n_integer_groups = static_cast<size_t>(weight) + 1;
    if (all_digits.size() < n_integer_groups) {
      all_digits.resize(n_integer_groups, 0);
    }
  }
  all_digits.insert(all_digits.end(), frac_digits.begin(), frac_digits.end());

  // Display scale = fractional decimal digits left after dropping trailing '0'
  // characters from the fractional text ("4500" -> 2).
  const size_t last_nonzero = parts.fractional_part.find_last_not_of('0');
  const int trailing_zeros =
      last_nonzero == std::string::npos
          ? static_cast<int>(parts.fractional_part.length())
          : static_cast<int>(parts.fractional_part.length() - 1 - last_nonzero);
  int16_t dscale =
      static_cast<int16_t>((std::max)(0, parts.effective_scale - trailing_zeros));

  if (frac_digits.empty()) {
    // No fractional group survived, so nothing is shown after the decimal point.
    dscale = 0;
  }
  if (all_digits.empty()) {
    // The value is zero: written with no digits and weight 0.
    final_weight = 0;
  }

  const int16_t ndigits = static_cast<int16_t>(all_digits.size());
  const int32_t field_size_bytes =
      static_cast<int32_t>(sizeof(ndigits) + sizeof(final_weight) + sizeof(sign) +
                           sizeof(dscale) + all_digits.size() * sizeof(int16_t));

  NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, final_weight, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

  // Reserve once so the per-digit writes below cannot fail.
  const size_t pg_digit_bytes = sizeof(int16_t) * all_digits.size();
  NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
  for (const int16_t pg_digit : all_digits) {
    WriteUnsafe<int16_t>(buffer, pg_digit);
  }

  return ADBC_STATUS_OK;
}

private:
// returns the length of the string
// Result of splitting a decimal's digit text around the decimal point; carries
// data from SplitDecimalParts to the digit-grouping helpers.
struct DecimalParts {
  std::string integer_part;     // digits left of the point, e.g., "12300" or "123"
  std::string fractional_part;  // digits right of the point, e.g., "45" or "00123"
  // Scale with negative values clamped to 0. Initialized so a default-constructed
  // value never holds an indeterminate int.
  int effective_scale = 0;
};

// Helper function implementations for decimal-to-PostgreSQL NUMERIC conversion

// Convert decimal to string (absolute value, no sign)
// Returns the length of the string
template <int32_t DEC_WIDTH>
int DecimalToString(struct ArrowDecimal* decimal, char* out) {
int DecimalToString(struct ArrowDecimal* decimal, char* out) const {
constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4;
uint8_t tmp[DEC_WIDTH / 8];
ArrowDecimalGetBytes(decimal, tmp);
Expand All @@ -322,10 +382,9 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
for (size_t i = 0; i < DEC_WIDTH; i++) {
int carry;

carry = (buf[nwords - 1] >= 0x7FFFFFFFFFFFFFFF);
carry = (buf[nwords - 1] > 0x7FFFFFFFFFFFFFFF);
for (size_t j = nwords - 1; j > 0; j--) {
buf[j] =
((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] >= 0x7FFFFFFFFFFFFFFF);
buf[j] = ((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] > 0x7FFFFFFFFFFFFFFF);
}
buf[0] = ((buf[0] << 1) & 0xFFFFFFFFFFFFFFFF);

Expand All @@ -350,6 +409,117 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
return ndigits;
}

// Splits `digit_count` unsigned decimal digits at `decimal_digits` into the text
// left and right of the decimal point implied by `scale`.
//   - scale < 0: no fractional text; -scale '0' characters are appended to the
//     integer text (value=123, scale=-2 -> "12300") and effective_scale is 0.
//   - scale >= 0: the last `scale` digits form the fractional text, left-padded
//     with '0' when there are fewer digits than `scale` (123, scale=5 -> "00123").
DecimalParts SplitDecimalParts(const char* decimal_digits, int digit_count,
                               int scale) const {
  const int appended_zeros = (std::max)(0, -scale);
  const int frac_scale = (std::max)(0, scale);
  const int logical_len = digit_count + appended_zeros;

  // Number of logical digits on each side of the decimal point.
  const int int_len = (std::max)(0, logical_len - frac_scale);
  const int frac_len = logical_len - int_len;

  DecimalParts result;
  result.effective_scale = frac_scale;

  if (int_len > 0) {
    if (int_len > digit_count) {
      // Every real digit belongs to the integer side, followed by the zeros that
      // a negative scale contributes.
      result.integer_part.assign(decimal_digits, digit_count);
      result.integer_part.append(appended_zeros, '0');
    } else {
      // The integer side is a prefix of the real digits.
      result.integer_part.assign(decimal_digits, int_len);
    }
  }

  if (int_len == 0 && logical_len < frac_scale) {
    // Fewer digits than the scale: pad on the left so the text is exactly
    // `frac_scale` characters long.
    result.fractional_part.assign(frac_scale - logical_len, '0');
    result.fractional_part.append(decimal_digits, digit_count);
  } else if (frac_len > 0 && int_len < digit_count) {
    // Whatever real digits follow the integer prefix.
    result.fractional_part.assign(decimal_digits + int_len, digit_count - int_len);
  }

  return result;
}

// Converts the integer digit text into base-10000 groups, most significant first.
// Returns {groups, weight} where weight = ceil(length / 4) - 1, or {{}, -1} when
// `int_part` is empty. All-zero groups at the least-significant end are not
// stored (the weight still reflects the full length).
std::pair<std::vector<int16_t>, int16_t> GroupIntegerDigits(
    const std::string& int_part) const {
  constexpr size_t kGroupWidth = 4;
  std::vector<int16_t> groups;

  if (int_part.empty()) {
    return {groups, -1};  // no integer digits: the value is purely fractional
  }

  const int16_t weight = static_cast<int16_t>(
      (int_part.length() + kGroupWidth - 1) / kGroupWidth - 1);

  // Walk from the least significant end in 4-character windows, collecting
  // groups in reverse; the leftmost window may be shorter than 4.
  const char* const text = int_part.data();
  size_t window_end = int_part.length();
  while (window_end > 0) {
    const size_t window_begin = window_end - (std::min)(window_end, kGroupWidth);

    int16_t group{};
    std::from_chars(text + window_begin, text + window_end, group);

    // Zero groups are dropped only until the first non-zero group is seen.
    if (group != 0 || !groups.empty()) {
      groups.push_back(group);
    }
    window_end = window_begin;
  }

  std::reverse(groups.begin(), groups.end());
  return {groups, weight};
}

// Converts the fractional digit text into base-10000 groups, left to right, with
// the final group right-padded with '0' to four digits ("45" -> 4500).
// Returns {groups, weight}. `weight` starts at `initial_weight`; when
// `has_integer_part` is false, each all-zero group preceding the first non-zero
// group is omitted and lowers the weight by one instead.
std::pair<std::vector<int16_t>, int16_t> GroupFractionalDigits(
    const std::string& frac_part, int16_t initial_weight, bool has_integer_part) const {
  constexpr size_t kGroupWidth = 4;
  std::vector<int16_t> groups;
  int16_t weight = initial_weight;

  // True while zero groups may still be dropped from the front.
  bool dropping_leading_zeros = !has_integer_part;

  for (size_t pos = 0; pos < frac_part.length(); pos += kGroupWidth) {
    // Start from "0000" and overwrite with the available digits, which
    // right-pads a short final group.
    char padded[kGroupWidth] = {'0', '0', '0', '0'};
    frac_part.copy(padded, kGroupWidth, pos);

    int16_t group{};
    std::from_chars(padded, padded + kGroupWidth, group);

    if (dropping_leading_zeros && group == 0) {
      weight--;
      continue;
    }
    dropping_leading_zeros = false;
    groups.push_back(group);
  }

  return {groups, weight};
}

static constexpr uint16_t kNumericPos = 0x0000;
static constexpr uint16_t kNumericNeg = 0x4000;
static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;
Expand Down
19 changes: 0 additions & 19 deletions c/driver/postgresql/validation/queries/ingest/decimal.toml

This file was deleted.

45 changes: 45 additions & 0 deletions c/driver/postgresql/validation/queries/ingest/decimal.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// part: expected_schema
{
"format": "+s",
"children": [
{
"name": "idx",
"format": "l",
"flags": ["nullable"]
},
{
"name": "value",
"format": "u",
"flags": ["nullable"],
"metadata": {
"ARROW:extension:name": "arrow.opaque",
"ARROW:extension:metadata": "{\"type_name\": \"numeric\", \"vendor_name\": \"PostgreSQL\"}"
}
}
]
}

// part: expected

{"idx": 0, "value": "0"}
{"idx": 1, "value": "123.45"}
{"idx": 2, "value": "-123.45"}
{"idx": 3, "value": "9999999.99"}
{"idx": 4, "value": "-9999999.99"}
Loading
Loading