diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 5011917082..dc225d5267 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -54,8 +54,8 @@ use datafusion::common::{ use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, - ToPrimitive, Zero, + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, + Zero, }; use regex::Regex; use std::str::FromStr; @@ -389,13 +389,23 @@ macro_rules! cast_utf8_to_int { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); let mut cast_array = PrimitiveArray::<$array_type>::builder(len); - for i in 0..len { - if $array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() + if $array.null_count() == 0 { + for i in 0..len { + if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } else { + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } } } let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); @@ -1954,82 +1964,216 @@ fn cast_string_to_int_with_range_check( } } +// Returns (start, end) indices after trimming whitespace +fn trim_whitespace(bytes: &[u8]) -> (usize, usize) { + let mut start = 0; + let mut end = bytes.len(); + + while start < end && bytes[start].is_ascii_whitespace() { + start += 1; + } + while end > start && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + + (start, end) +} + +// Parses sign and returns (is_negative, start_idx after sign) +// Returns None if invalid (e.g., just "+" or "-") +fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> { + let len = trimmed_bytes.len(); + if len == 0 { + return None; + } + + let first_char = trimmed_bytes[0]; + let negative = first_char == b'-'; + + if negative || first_char == b'+' { + if len == 1 { + return None; + } + Some((negative, 1)) + } else { + Some((false, 0)) + } +} + /// Equivalent to /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) -fn do_cast_string_to_int< - T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, ->( +fn do_parse_string_to_int_legacy + Copy>( str: &str, - eval_mode: EvalMode, - type_name: &str, min_value: T, ) -> SparkResult> { - let trimmed_str = str.trim(); - if trimmed_str.is_empty() { - return none_or_err(eval_mode, type_name, str); + let bytes = str.as_bytes(); + let (start, end) = trim_whitespace(bytes); + + if start == end { + return Ok(None); } - let len = trimmed_str.len(); + let trimmed_bytes = &bytes[start..end]; + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + let mut result: T = T::zero(); - let mut negative = false; - let radix = T::from(10); + + let radix = T::from(10_u8); let stop_value = min_value / radix; let mut parse_sign_and_digits = true; - for (i, ch) in trimmed_str.char_indices() { + for &ch in &trimmed_bytes[idx..] { if parse_sign_and_digits { - if i == 0 { - negative = ch == '-'; - let positive = ch == '+'; - if negative || positive { - if i + 1 == len { - // input string is just "+" or "-" - return none_or_err(eval_mode, type_name, str); - } - // consume this char - continue; - } + if ch == b'.' { + // truncate decimal in legacy mode + parse_sign_and_digits = false; + continue; } - if ch == '.' { - if eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - parse_sign_and_digits = false; - continue; - } else { - return none_or_err(eval_mode, type_name, str); - } + if !ch.is_ascii_digit() { + return Ok(None); } - let digit = if ch.is_ascii_digit() { - (ch as u32) - ('0' as u32) - } else { - return none_or_err(eval_mode, type_name, str); - }; + let digit: T = T::from(ch - b'0'); - // We are going to process the new digit and accumulate the result. However, before - // doing this, if the result is already smaller than the - // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be - // smaller than minValue, and we can stop if result < stop_value { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } - - // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / - // radix), we can just use `result > 0` to check overflow. If result - // overflows, we should stop let v = result * radix; - let digit = (digit as i32).into(); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } } } else { - // make sure fractional digits are valid digits but ignore them + // in legacy mode we still process chars after the dot and make sure the chars are digits if !ch.is_ascii_digit() { - return none_or_err(eval_mode, type_name, str); + return Ok(None); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return Ok(None); + } + result = neg; + } else { + return Ok(None); + } + } + + Ok(Some(result)) +} + +fn do_parse_string_to_int_ansi + Copy>( + str: &str, + type_name: &str, + min_value: T, +) -> SparkResult> { + let bytes = str.as_bytes(); + let (start, end) = trim_whitespace(bytes); + + if start == end { + return Err(invalid_value(str, "STRING", type_name)); + } + let trimmed_bytes = &bytes[start..end]; + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Err(invalid_value(str, "STRING", type_name)), + }; + + let mut result: T = T::zero(); + + let radix = T::from(10_u8); + let stop_value = min_value / radix; + + for &ch in &trimmed_bytes[idx..] { + if ch == b'.' { + return Err(invalid_value(str, "STRING", type_name)); + } + + if !ch.is_ascii_digit() { + return Err(invalid_value(str, "STRING", type_name)); + } + + let digit: T = T::from(ch - b'0'); + + if result < stop_value { + return Err(invalid_value(str, "STRING", type_name)); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return Err(invalid_value(str, "STRING", type_name)); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return Err(invalid_value(str, "STRING", type_name)); + } + result = neg; + } else { + return Err(invalid_value(str, "STRING", type_name)); + } + } + + Ok(Some(result)) +} + +fn do_parse_string_to_int_try + Copy>( + str: &str, + min_value: T, +) -> SparkResult> { + let bytes = str.as_bytes(); + let (start, end) = trim_whitespace(bytes); + + if start == end { + return Ok(None); + } + let trimmed_bytes = &bytes[start..end]; + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + + let mut result: T = T::zero(); + + let radix = T::from(10_u8); + let stop_value = min_value / radix; + + // we don't have to go beyond decimal point in try eval mode - early return NULL + for &ch in &trimmed_bytes[idx..] { + if ch == b'.' { + return Ok(None); + } + + if !ch.is_ascii_digit() { + return Ok(None); + } + + let digit: T = T::from(ch - b'0'); + + if result < stop_value { + return Ok(None); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return Ok(None); } } } @@ -2037,17 +2181,30 @@ fn do_cast_string_to_int< if !negative { if let Some(neg) = result.checked_neg() { if neg < T::zero() { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } result = neg; } else { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } } Ok(Some(result)) } +fn do_cast_string_to_int + Copy>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> SparkResult> { + match eval_mode { + EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value), + EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value), + EvalMode::Try => do_parse_string_to_int_try(str, min_value), + } +} + fn cast_string_to_decimal( array: &ArrayRef, to_type: &DataType, @@ -2348,15 +2505,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { Ok((final_mantissa, final_scale)) } -/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode -#[inline] -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult> { - match eval_mode { - EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), - _ => Ok(None), - } -} - #[inline] fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { SparkError::CastInvalidValue {