Skip to content

Commit a6cfadb

Browse files
authored
feat: Improve compatibility of string to decimal cast (#2925)
1 parent 1076f35 commit a6cfadb

File tree

4 files changed

+395
-24
lines changed

4 files changed

+395
-24
lines changed

docs/source/user-guide/latest/compatibility.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ The following cast operations are not compatible with Spark for all inputs and a
183183
| double | decimal | There can be rounding differences |
184184
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
185185
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
186-
| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
186+
| string | decimal | Does not support fullwidth unicode digits (e.g. \\uFF10) or strings containing null bytes (e.g. \\u0000) |
187188
| string | timestamp | Not all valid formats are supported |
188189
<!-- prettier-ignore-end -->
189190
<!--END:INCOMPAT_CAST_TABLE-->

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 312 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
2020
use crate::{EvalMode, SparkError, SparkResult};
2121
use arrow::array::builder::StringBuilder;
2222
use arrow::array::{
23-
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
24-
StructArray,
23+
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
24+
PrimitiveBuilder, StringArray, StructArray,
2525
};
2626
use arrow::compute::can_cast_types;
2727
use arrow::datatypes::{
28-
ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
28+
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
29+
Schema,
2930
};
3031
use arrow::{
3132
array::{
@@ -224,9 +225,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
224225
}
225226
Decimal128(_, _) => {
226227
// https://github.com/apache/datafusion-comet/issues/325
227-
// Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
228-
// Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
229-
228+
// Does not support fullwidth digits and null byte handling.
230229
options.allow_incompat
231230
}
232231
Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
976975
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
977976
}
978977
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
978+
(Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
979+
cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
980+
}
981+
(Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
982+
cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
983+
}
979984
(Int64, Int32)
980985
| (Int64, Int16)
981986
| (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
11871192
),
11881193
DataType::Utf8 if allow_incompat => matches!(
11891194
to_type,
1190-
DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
1195+
DataType::Binary | DataType::Float32 | DataType::Float64
11911196
),
11921197
DataType::Utf8 => matches!(to_type, DataType::Binary),
11931198
DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
19761981
Ok(Some(result))
19771982
}
19781983

1984+
fn cast_string_to_decimal(
1985+
array: &ArrayRef,
1986+
to_type: &DataType,
1987+
precision: &u8,
1988+
scale: &i8,
1989+
eval_mode: EvalMode,
1990+
) -> SparkResult<ArrayRef> {
1991+
match to_type {
1992+
DataType::Decimal128(_, _) => {
1993+
cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
1994+
}
1995+
DataType::Decimal256(_, _) => {
1996+
cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
1997+
}
1998+
_ => Err(SparkError::Internal(format!(
1999+
"Unexpected type in cast_string_to_decimal: {:?}",
2000+
to_type
2001+
))),
2002+
}
2003+
}
2004+
2005+
fn cast_string_to_decimal128_impl(
2006+
array: &ArrayRef,
2007+
eval_mode: EvalMode,
2008+
precision: u8,
2009+
scale: i8,
2010+
) -> SparkResult<ArrayRef> {
2011+
let string_array = array
2012+
.as_any()
2013+
.downcast_ref::<StringArray>()
2014+
.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
2015+
2016+
let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
2017+
2018+
for i in 0..string_array.len() {
2019+
if string_array.is_null(i) {
2020+
decimal_builder.append_null();
2021+
} else {
2022+
let str_value = string_array.value(i);
2023+
match parse_string_to_decimal(str_value, precision, scale) {
2024+
Ok(Some(decimal_value)) => {
2025+
decimal_builder.append_value(decimal_value);
2026+
}
2027+
Ok(None) => {
2028+
if eval_mode == EvalMode::Ansi {
2029+
return Err(invalid_value(
2030+
string_array.value(i),
2031+
"STRING",
2032+
&format!("DECIMAL({},{})", precision, scale),
2033+
));
2034+
}
2035+
decimal_builder.append_null();
2036+
}
2037+
Err(e) => {
2038+
if eval_mode == EvalMode::Ansi {
2039+
return Err(e);
2040+
}
2041+
decimal_builder.append_null();
2042+
}
2043+
}
2044+
}
2045+
}
2046+
2047+
Ok(Arc::new(
2048+
decimal_builder
2049+
.with_precision_and_scale(precision, scale)?
2050+
.finish(),
2051+
))
2052+
}
2053+
2054+
fn cast_string_to_decimal256_impl(
2055+
array: &ArrayRef,
2056+
eval_mode: EvalMode,
2057+
precision: u8,
2058+
scale: i8,
2059+
) -> SparkResult<ArrayRef> {
2060+
let string_array = array
2061+
.as_any()
2062+
.downcast_ref::<StringArray>()
2063+
.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
2064+
2065+
let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
2066+
2067+
for i in 0..string_array.len() {
2068+
if string_array.is_null(i) {
2069+
decimal_builder.append_null();
2070+
} else {
2071+
let str_value = string_array.value(i);
2072+
match parse_string_to_decimal(str_value, precision, scale) {
2073+
Ok(Some(decimal_value)) => {
2074+
// Convert i128 to i256
2075+
let i256_value = i256::from_i128(decimal_value);
2076+
decimal_builder.append_value(i256_value);
2077+
}
2078+
Ok(None) => {
2079+
if eval_mode == EvalMode::Ansi {
2080+
return Err(invalid_value(
2081+
str_value,
2082+
"STRING",
2083+
&format!("DECIMAL({},{})", precision, scale),
2084+
));
2085+
}
2086+
decimal_builder.append_null();
2087+
}
2088+
Err(e) => {
2089+
if eval_mode == EvalMode::Ansi {
2090+
return Err(e);
2091+
}
2092+
decimal_builder.append_null();
2093+
}
2094+
}
2095+
}
2096+
}
2097+
2098+
Ok(Arc::new(
2099+
decimal_builder
2100+
.with_precision_and_scale(precision, scale)?
2101+
.finish(),
2102+
))
2103+
}
2104+
2105+
/// Parse a string to decimal following Spark's behavior
2106+
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
2107+
let string_bytes = s.as_bytes();
2108+
let mut start = 0;
2109+
let mut end = string_bytes.len();
2110+
2111+
// trim whitespaces
2112+
while start < end && string_bytes[start].is_ascii_whitespace() {
2113+
start += 1;
2114+
}
2115+
while end > start && string_bytes[end - 1].is_ascii_whitespace() {
2116+
end -= 1;
2117+
}
2118+
2119+
let trimmed = &s[start..end];
2120+
2121+
if trimmed.is_empty() {
2122+
return Ok(None);
2123+
}
2124+
// Handle special values (inf, nan, etc.)
2125+
if trimmed.eq_ignore_ascii_case("inf")
2126+
|| trimmed.eq_ignore_ascii_case("+inf")
2127+
|| trimmed.eq_ignore_ascii_case("infinity")
2128+
|| trimmed.eq_ignore_ascii_case("+infinity")
2129+
|| trimmed.eq_ignore_ascii_case("-inf")
2130+
|| trimmed.eq_ignore_ascii_case("-infinity")
2131+
|| trimmed.eq_ignore_ascii_case("nan")
2132+
{
2133+
return Ok(None);
2134+
}
2135+
2136+
// validate and parse mantissa and exponent
2137+
match parse_decimal_str(trimmed) {
2138+
Ok((mantissa, exponent)) => {
2139+
// Convert to target scale
2140+
let target_scale = scale as i32;
2141+
let scale_adjustment = target_scale - exponent;
2142+
2143+
let scaled_value = if scale_adjustment >= 0 {
2144+
// Need to multiply (increase scale) but return None if scale is too high to fit i128
2145+
if scale_adjustment > 38 {
2146+
return Ok(None);
2147+
}
2148+
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
2149+
} else {
2150+
// Need to multiply (increase scale) but return None if scale is too high to fit i128
2151+
let abs_scale_adjustment = (-scale_adjustment) as u32;
2152+
if abs_scale_adjustment > 38 {
2153+
return Ok(Some(0));
2154+
}
2155+
2156+
let divisor = 10_i128.pow(abs_scale_adjustment);
2157+
let quotient_opt = mantissa.checked_div(divisor);
2158+
// Check if divisor is 0
2159+
if quotient_opt.is_none() {
2160+
return Ok(None);
2161+
}
2162+
let quotient = quotient_opt.unwrap();
2163+
let remainder = mantissa % divisor;
2164+
2165+
// Round half up: if abs(remainder) >= divisor/2, round away from zero
2166+
let half_divisor = divisor / 2;
2167+
let rounded = if remainder.abs() >= half_divisor {
2168+
if mantissa >= 0 {
2169+
quotient + 1
2170+
} else {
2171+
quotient - 1
2172+
}
2173+
} else {
2174+
quotient
2175+
};
2176+
Some(rounded)
2177+
};
2178+
2179+
match scaled_value {
2180+
Some(value) => {
2181+
// Check if it fits target precision
2182+
if is_validate_decimal_precision(value, precision) {
2183+
Ok(Some(value))
2184+
} else {
2185+
Ok(None)
2186+
}
2187+
}
2188+
None => {
2189+
// Overflow while scaling
2190+
Ok(None)
2191+
}
2192+
}
2193+
}
2194+
Err(_) => Ok(None),
2195+
}
2196+
}
2197+
2198+
/// Parse a decimal string into an unscaled mantissa and a scale.
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
///
/// The returned scale is `fractional_digits - exponent`, so it may be
/// negative (e.g. "1e3" -> (1, -3)). Errors describe why the string is not a
/// valid decimal; callers treat any error as "not a number".
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
    if s.is_empty() {
        return Err("Empty string".to_string());
    }

    // Split off an optional exponent introduced by 'e' or 'E'.
    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
        let mantissa_part = &s[..e_pos];
        let exponent_part = &s[e_pos + 1..];
        // Parse exponent
        let exp: i32 = exponent_part
            .parse()
            .map_err(|e| format!("Invalid exponent: {}", e))?;

        (mantissa_part, exp)
    } else {
        (s, 0)
    };

    // Strip a single leading sign; any further sign character is invalid.
    let negative = mantissa_str.starts_with('-');
    let mantissa_str = if negative || mantissa_str.starts_with('+') {
        &mantissa_str[1..]
    } else {
        mantissa_str
    };

    if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
        return Err("Invalid sign format".to_string());
    }

    // Split around the first '.'; a second '.' is invalid.
    let (integral_part, fractional_part) = match mantissa_str.find('.') {
        Some(dot_pos) => {
            if mantissa_str[dot_pos + 1..].contains('.') {
                return Err("Multiple decimal points".to_string());
            }
            (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
        }
        None => (mantissa_str, ""),
    };

    if integral_part.is_empty() && fractional_part.is_empty() {
        return Err("No digits found".to_string());
    }

    if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
        return Err("Invalid integral part".to_string());
    }

    if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
        return Err("Invalid fractional part".to_string());
    }

    // Parse integral part
    let integral_value: i128 = if integral_part.is_empty() {
        // Empty integral part is valid (e.g., ".5" or "-.7e9")
        0
    } else {
        integral_part
            .parse()
            .map_err(|_| "Invalid integral part".to_string())?
    };

    // Parse fractional part
    let fractional_scale = fractional_part.len() as i32;
    let fractional_value: i128 = if fractional_part.is_empty() {
        0
    } else {
        fractional_part
            .parse()
            .map_err(|_| "Invalid fractional part".to_string())?
    };

    // Combine: value = integral * 10^fractional_scale + fractional.
    // Use checked_pow as well: a fractional part longer than 38 digits makes
    // 10^fractional_scale overflow i128, which must surface as an error
    // rather than a panic (debug) or silent wrap (release).
    let mantissa = 10_i128
        .checked_pow(fractional_scale as u32)
        .and_then(|pow| integral_value.checked_mul(pow))
        .and_then(|v| v.checked_add(fractional_value))
        .ok_or("Overflow in mantissa calculation")?;

    let final_mantissa = if negative { -mantissa } else { mantissa };
    // final scale = fractional_scale - exponent
    // For example : "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
    let final_scale = fractional_scale - exponent;
    Ok((final_mantissa, final_scale))
}
2283+
19792284
/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
19802285
#[inline]
19812286
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
192192
"Does not support ANSI mode."))
193193
case _: DecimalType =>
194194
// https://github.com/apache/datafusion-comet/issues/325
195-
Incompatible(
196-
Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
197-
"Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
195+
Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10)
196+
|or strings containing null bytes (e.g \\u0000)""".stripMargin))
198197
case DataTypes.DateType =>
199198
// https://github.com/apache/datafusion-comet/issues/327
200199
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))

0 commit comments

Comments
 (0)