From 42a6d6418ba3d85775ffdb10dc93db0102e8cf8b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 30 Dec 2025 17:30:22 -0800 Subject: [PATCH 1/7] perf_string_to_int --- .../spark-expr/src/conversion_funcs/cast.rs | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 5011917082..4be0d6e2a9 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1965,22 +1965,37 @@ fn do_cast_string_to_int< type_name: &str, min_value: T, ) -> SparkResult> { - let trimmed_str = str.trim(); - if trimmed_str.is_empty() { + + let bytes = str.as_bytes(); + let mut start = 0; + let mut end = bytes.len(); + + while start < end && bytes[start].is_ascii_whitespace() { + start += 1; + } + while end > start && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + + if start == end { return none_or_err(eval_mode, type_name, str); } + let trimmed_str = &str[start..end]; let len = trimmed_str.len(); + let trimmed_bytes = trimmed_str.as_bytes(); + let mut result: T = T::zero(); let mut negative = false; let radix = T::from(10); let stop_value = min_value / radix; let mut parse_sign_and_digits = true; - for (i, ch) in trimmed_str.char_indices() { + for i in 0..len { + let ch = trimmed_bytes[i]; if parse_sign_and_digits { if i == 0 { - negative = ch == '-'; - let positive = ch == '+'; + negative = ch == b'-'; + let positive = ch == b'+'; if negative || positive { if i + 1 == len { // input string is just "+" or "-" @@ -1991,7 +2006,7 @@ fn do_cast_string_to_int< } } - if ch == '.' { + if ch == b'.' { if eval_mode == EvalMode::Legacy { // truncate decimal in legacy mode parse_sign_and_digits = false; From 768a0175ebefc4ab120eec453bc2aacbf8baa5f1 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 30 Dec 2025 22:08:13 -0800 Subject: [PATCH 2/7] perf_string_to_int --- .../spark-expr/src/conversion_funcs/cast.rs | 46 +++++++------------ 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 4be0d6e2a9..1dc3d6e4fb 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1983,29 +1983,23 @@ fn do_cast_string_to_int< let trimmed_str = &str[start..end]; let len = trimmed_str.len(); let trimmed_bytes = trimmed_str.as_bytes(); - let mut result: T = T::zero(); - let mut negative = false; + let mut idx = 0; + let first_char = trimmed_bytes[0]; + let negative = first_char == b'-'; + if negative || first_char == b'+' { + idx = 1; + if len == 1{ + return none_or_err(eval_mode, type_name, str); + } + } + let radix = T::from(10); let stop_value = min_value / radix; let mut parse_sign_and_digits = true; - for i in 0..len { - let ch = trimmed_bytes[i]; + for &ch in &trimmed_bytes[idx..] { if parse_sign_and_digits { - if i == 0 { - negative = ch == b'-'; - let positive = ch == b'+'; - if negative || positive { - if i + 1 == len { - // input string is just "+" or "-" - return none_or_err(eval_mode, type_name, str); - } - // consume this char - continue; - } - } - if ch == b'.' { if eval_mode == EvalMode::Legacy { // truncate decimal in legacy mode @@ -2016,11 +2010,11 @@ fn do_cast_string_to_int< } } - let digit = if ch.is_ascii_digit() { - (ch as u32) - ('0' as u32) - } else { + if !ch.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); - }; + } + let digit = T::from((ch - b'0') as i32); + result = result * radix - digit; // We are going to process the new digit and accumulate the result. However, before // doing this, if the result is already smaller than the @@ -2029,17 +2023,11 @@ fn do_cast_string_to_int< if result < stop_value { return none_or_err(eval_mode, type_name, str); } - // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / // radix), we can just use `result > 0` to check overflow. If result // overflows, we should stop - let v = result * radix; - let digit = (digit as i32).into(); - match v.checked_sub(&digit) { - Some(x) if x <= T::zero() => result = x, - _ => { - return none_or_err(eval_mode, type_name, str); - } + if result > T::zero() { + return none_or_err(eval_mode, type_name, str); } } else { // make sure fractional digits are valid digits but ignore them From ca9058753e28a2dd1e0dede67a4e6309b12d46d1 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 31 Dec 2025 18:34:15 -0800 Subject: [PATCH 3/7] perf_string_to_int --- .../spark-expr/src/conversion_funcs/cast.rs | 62 ++++++++++++++++--- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 1dc3d6e4fb..931b3118c6 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -389,19 +389,32 @@ macro_rules! cast_utf8_to_int { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); let mut cast_array = PrimitiveArray::<$array_type>::builder(len); - for i in 0..len { - if $array.is_null(i) { - cast_array.append_null() - } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { - cast_array.append_value(cast_value); - } else { - cast_array.append_null() + + if $array.null_count() == 0 { + for i in 0..len { + if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + } else { + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } } } + let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); result }}; } + macro_rules! cast_utf8_to_timestamp { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{ let len = $array.len(); @@ -1931,6 +1944,35 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { + // happy path + let bytes = str.as_bytes(); + let len = bytes.len(); + if len > 0 && len <= 10 { + // SAFETY: We checked len > 0 above + let first = unsafe { *bytes.get_unchecked(0) }; + // Must start with digit for happy path + if first >= b'0' && first <= b'9' { + let mut result: i64 = (first - b'0') as i64; + let mut i = 1; + + // Try to parse remaining digits + while i < len { + let b = bytes[i]; + if b >= b'0' && b <= b'9' { + result = result * 10 + (b - b'0') as i64; + i += 1; + } else { + // Hit non-digit (space, sign, decimal, etc.) - Bail to slow path + break; + } + } + if i == len && result <= i32::MAX as i64 { + return Ok(Some(result as i32)); + } + // Otherwise fall through to slow path + } + } + do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) } @@ -1965,7 +2007,6 @@ fn do_cast_string_to_int< type_name: &str, min_value: T, ) -> SparkResult> { - let bytes = str.as_bytes(); let mut start = 0; let mut end = bytes.len(); @@ -1989,7 +2030,7 @@ fn do_cast_string_to_int< let negative = first_char == b'-'; if negative || first_char == b'+' { idx = 1; - if len == 1{ + if len == 1 { return none_or_err(eval_mode, type_name, str); } } @@ -1998,7 +2039,7 @@ fn do_cast_string_to_int< let stop_value = min_value / radix; let mut parse_sign_and_digits = true; - for &ch in &trimmed_bytes[idx..] { + for &ch in &trimmed_bytes[idx..] { if parse_sign_and_digits { if ch == b'.' { if eval_mode == EvalMode::Legacy { @@ -2014,6 +2055,7 @@ fn do_cast_string_to_int< return none_or_err(eval_mode, type_name, str); } let digit = T::from((ch - b'0') as i32); + result = (result << 3) + (result << 1) - digit; result = result * radix - digit; // We are going to process the new digit and accumulate the result. However, before From 2b331f01c023f45e0577c59b9dbb56fbfdb34583 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 31 Dec 2025 19:51:22 -0800 Subject: [PATCH 4/7] perf_string_to_int --- .../spark-expr/src/conversion_funcs/cast.rs | 50 ++++--------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 931b3118c6..bd6eb9c29f 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -389,7 +389,6 @@ macro_rules! cast_utf8_to_int { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); let mut cast_array = PrimitiveArray::<$array_type>::builder(len); - if $array.null_count() == 0 { for i in 0..len { if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { @@ -409,12 +408,10 @@ macro_rules! cast_utf8_to_int { } } } - let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); result }}; } - macro_rules! cast_utf8_to_timestamp { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{ let len = $array.len(); @@ -1944,35 +1941,6 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { - // happy path - let bytes = str.as_bytes(); - let len = bytes.len(); - if len > 0 && len <= 10 { - // SAFETY: We checked len > 0 above - let first = unsafe { *bytes.get_unchecked(0) }; - // Must start with digit for happy path - if first >= b'0' && first <= b'9' { - let mut result: i64 = (first - b'0') as i64; - let mut i = 1; - - // Try to parse remaining digits - while i < len { - let b = bytes[i]; - if b >= b'0' && b <= b'9' { - result = result * 10 + (b - b'0') as i64; - i += 1; - } else { - // Hit non-digit (space, sign, decimal, etc.) - Bail to slow path - break; - } - } - if i == len && result <= i32::MAX as i64 { - return Ok(Some(result as i32)); - } - // Otherwise fall through to slow path - } - } - do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) } @@ -2051,12 +2019,11 @@ fn do_cast_string_to_int< } } - if !ch.is_ascii_digit() { + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { return none_or_err(eval_mode, type_name, str); - } - let digit = T::from((ch - b'0') as i32); - result = (result << 3) + (result << 1) - digit; - result = result * radix - digit; + }; // We are going to process the new digit and accumulate the result. However, before // doing this, if the result is already smaller than the @@ -2068,8 +2035,13 @@ fn do_cast_string_to_int< // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / // radix), we can just use `result > 0` to check overflow. If result // overflows, we should stop - if result > T::zero() { - return none_or_err(eval_mode, type_name, str); + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } } } else { // make sure fractional digits are valid digits but ignore them From 5abce5f455f2d03d20a91d1a766ad107d92e93c4 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Sun, 4 Jan 2026 15:53:54 -0800 Subject: [PATCH 5/7] improve_benchmark --- .../spark-expr/src/conversion_funcs/cast.rs | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index bd6eb9c29f..36458f84fb 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -54,8 +54,8 @@ use datafusion::common::{ use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, - ToPrimitive, Zero, + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, + Zero, }; use regex::Regex; use std::str::FromStr; @@ -1967,9 +1967,7 @@ fn cast_string_to_int_with_range_check( /// Equivalent to /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) -fn do_cast_string_to_int< - T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, ->( +fn do_cast_string_to_int + Copy>( str: &str, eval_mode: EvalMode, type_name: &str, @@ -1989,9 +1987,8 @@ fn do_cast_string_to_int< if start == end { return none_or_err(eval_mode, type_name, str); } - let trimmed_str = &str[start..end]; - let len = trimmed_str.len(); - let trimmed_bytes = trimmed_str.as_bytes(); + let trimmed_bytes = &bytes[start..end]; + let len = trimmed_bytes.len(); let mut result: T = T::zero(); let mut idx = 0; let first_char = trimmed_bytes[0]; @@ -2003,7 +2000,7 @@ fn do_cast_string_to_int< } } - let radix = T::from(10); + let radix = T::from(10_u8); let stop_value = min_value / radix; let mut parse_sign_and_digits = true; @@ -2019,11 +2016,12 @@ fn do_cast_string_to_int< } } - let digit = if ch.is_ascii_digit() { - (ch as u32) - ('0' as u32) - } else { + if !ch.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); - }; + } + + // Direct conversion: u8 digit (0-9) → T + let digit: T = T::from(ch - b'0'); // We are going to process the new digit and accumulate the result. However, before // doing this, if the result is already smaller than the @@ -2036,7 +2034,6 @@ fn do_cast_string_to_int< // radix), we can just use `result > 0` to check overflow. If result // overflows, we should stop let v = result * radix; - let digit = (digit as i32).into(); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { From 94d0f3206561e6b65fd7f6d27aadf7c9a226c19b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 5 Jan 2026 15:05:14 -0800 Subject: [PATCH 6/7] improve_benchmark_remove_unwanted_branching --- .../spark-expr/src/conversion_funcs/cast.rs | 205 +++++++++++++++--- 1 file changed, 180 insertions(+), 25 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 36458f84fb..f51c27a0b6 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1967,9 +1967,88 @@ fn cast_string_to_int_with_range_check( /// Equivalent to /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) -fn do_cast_string_to_int + Copy>( +fn do_parse_string_to_int_legacy + Copy>( + str: &str, + min_value: T, +) -> SparkResult> { + let bytes = str.as_bytes(); + let mut start = 0; + let mut end = bytes.len(); + + while start < end && bytes[start].is_ascii_whitespace() { + start += 1; + } + while end > start && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + + if start == end { + return Ok(None); + } + let trimmed_bytes = &bytes[start..end]; + let len = trimmed_bytes.len(); + let mut result: T = T::zero(); + let mut idx = 0; + let first_char = trimmed_bytes[0]; + let negative = first_char == b'-'; + if negative || first_char == b'+' { + idx = 1; + if len == 1 { + return Ok(None); + } + } + + let radix = T::from(10_u8); + let stop_value = min_value / radix; + let mut parse_sign_and_digits = true; + + for &ch in &trimmed_bytes[idx..] { + if parse_sign_and_digits { + if ch == b'.' { + // truncate decimal in legacy mode + parse_sign_and_digits = false; + continue; + } + + if !ch.is_ascii_digit() { + return Ok(None); + } + + let digit: T = T::from(ch - b'0'); + + if result < stop_value { + return Ok(None); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return Ok(None); + } + } + } else { + if !ch.is_ascii_digit() { + return Ok(None); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return Ok(None); + } + result = neg; + } else { + return Ok(None); + } + } + + Ok(Some(result)) +} + +fn do_parse_string_to_int_ansi + Copy>( str: &str, - eval_mode: EvalMode, type_name: &str, min_value: T, ) -> SparkResult> { @@ -1985,7 +2064,7 @@ fn do_cast_string_to_int + Copy> } if start == end { - return none_or_err(eval_mode, type_name, str); + return Err(invalid_value(str, "STRING", type_name)); } let trimmed_bytes = &bytes[start..end]; let len = trimmed_bytes.len(); @@ -1996,7 +2075,7 @@ fn do_cast_string_to_int + Copy> if negative || first_char == b'+' { idx = 1; if len == 1 { - return none_or_err(eval_mode, type_name, str); + return Err(invalid_value(str, "STRING", type_name)); } } @@ -2007,43 +2086,106 @@ fn do_cast_string_to_int + Copy> for &ch in &trimmed_bytes[idx..] { if parse_sign_and_digits { if ch == b'.' { - if eval_mode == EvalMode::Legacy { - // truncate decimal in legacy mode - parse_sign_and_digits = false; - continue; - } else { - return none_or_err(eval_mode, type_name, str); + return Err(invalid_value(str, "STRING", type_name)); + } + + if !ch.is_ascii_digit() { + return Err(invalid_value(str, "STRING", type_name)); + } + + let digit: T = T::from(ch - b'0'); + + if result < stop_value { + return Err(invalid_value(str, "STRING", type_name)); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return Err(invalid_value(str, "STRING", type_name)); } } + } else { + if !ch.is_ascii_digit() { + return Err(invalid_value(str, "STRING", type_name)); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return Err(invalid_value(str, "STRING", type_name)); + } + result = neg; + } else { + return Err(invalid_value(str, "STRING", type_name)); + } + } + + Ok(Some(result)) +} + +fn do_parse_string_to_int_try + Copy>( + str: &str, + min_value: T, +) -> SparkResult> { + let bytes = str.as_bytes(); + let mut start = 0; + let mut end = bytes.len(); + + while start < end && bytes[start].is_ascii_whitespace() { + start += 1; + } + while end > start && bytes[end - 1].is_ascii_whitespace() { + end -= 1; + } + + if start == end { + return Ok(None); + } + let trimmed_bytes = &bytes[start..end]; + let len = trimmed_bytes.len(); + let mut result: T = T::zero(); + let mut idx = 0; + let first_char = trimmed_bytes[0]; + let negative = first_char == b'-'; + if negative || first_char == b'+' { + idx = 1; + if len == 1 { + return Ok(None); + } + } + + let radix = T::from(10_u8); + let stop_value = min_value / radix; + let mut parse_sign_and_digits = true; + + for &ch in &trimmed_bytes[idx..] { + if parse_sign_and_digits { + if ch == b'.' { + return Ok(None); + } if !ch.is_ascii_digit() { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } - // Direct conversion: u8 digit (0-9) → T let digit: T = T::from(ch - b'0'); - // We are going to process the new digit and accumulate the result. However, before - // doing this, if the result is already smaller than the - // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be - // smaller than minValue, and we can stop if result < stop_value { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } - // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / - // radix), we can just use `result > 0` to check overflow. If result - // overflows, we should stop let v = result * radix; match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } } } else { - // make sure fractional digits are valid digits but ignore them if !ch.is_ascii_digit() { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } } } @@ -2051,17 +2193,30 @@ fn do_cast_string_to_int + Copy> if !negative { if let Some(neg) = result.checked_neg() { if neg < T::zero() { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } result = neg; } else { - return none_or_err(eval_mode, type_name, str); + return Ok(None); } } Ok(Some(result)) } +fn do_cast_string_to_int + Copy>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> SparkResult> { + match eval_mode { + EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value), + EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value), + EvalMode::Try => do_parse_string_to_int_try(str, min_value), + } +} + fn cast_string_to_decimal( array: &ArrayRef, to_type: &DataType, From f1299eb75972038a6add948e37adacff9bd1436f Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 5 Jan 2026 19:20:23 -0800 Subject: [PATCH 7/7] improve_benchmark_remove_unwanted_branching_per_eval_mode --- .../spark-expr/src/conversion_funcs/cast.rs | 189 ++++++++---------- 1 file changed, 84 insertions(+), 105 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index f51c27a0b6..dc225d5267 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1964,14 +1964,8 @@ fn cast_string_to_int_with_range_check( } } -/// Equivalent to -/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) -/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) -fn do_parse_string_to_int_legacy + Copy>( - str: &str, - min_value: T, -) -> SparkResult> { - let bytes = str.as_bytes(); +// Returns (start, end) indices after trimming whitespace +fn trim_whitespace(bytes: &[u8]) -> (usize, usize) { let mut start = 0; let mut end = bytes.len(); @@ -1982,21 +1976,51 @@ fn do_parse_string_to_int_legacy end -= 1; } - if start == end { - return Ok(None); - } - let trimmed_bytes = &bytes[start..end]; + (start, end) +} + +// Parses sign and returns (is_negative, start_idx after sign) +// Returns None if invalid (e.g., just "+" or "-") +fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> { let len = trimmed_bytes.len(); - let mut result: T = T::zero(); - let mut idx = 0; + if len == 0 { + return None; + } + let first_char = trimmed_bytes[0]; let negative = first_char == b'-'; + if negative || first_char == b'+' { - idx = 1; if len == 1 { - return Ok(None); + return None; } + Some((negative, 1)) + } else { + Some((false, 0)) } +} + +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) +fn do_parse_string_to_int_legacy + Copy>( + str: &str, + min_value: T, +) -> SparkResult> { + let bytes = str.as_bytes(); + let (start, end) = trim_whitespace(bytes); + + if start == end { + return Ok(None); + } + let trimmed_bytes = &bytes[start..end]; + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + + let mut result: T = T::zero(); let radix = T::from(10_u8); let stop_value = min_value / radix; @@ -2027,6 +2051,7 @@ fn do_parse_string_to_int_legacy } } } else { + // in legacy mode we still process chars after the dot and make sure the chars are digits if !ch.is_ascii_digit() { return Ok(None); } @@ -2053,60 +2078,41 @@ fn do_parse_string_to_int_ansi + min_value: T, ) -> SparkResult> { let bytes = str.as_bytes(); - let mut start = 0; - let mut end = bytes.len(); - - while start < end && bytes[start].is_ascii_whitespace() { - start += 1; - } - while end > start && bytes[end - 1].is_ascii_whitespace() { - end -= 1; - } + let (start, end) = trim_whitespace(bytes); if start == end { return Err(invalid_value(str, "STRING", type_name)); } let trimmed_bytes = &bytes[start..end]; - let len = trimmed_bytes.len(); + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Err(invalid_value(str, "STRING", type_name)), + }; + let mut result: T = T::zero(); - let mut idx = 0; - let first_char = trimmed_bytes[0]; - let negative = first_char == b'-'; - if negative || first_char == b'+' { - idx = 1; - if len == 1 { - return Err(invalid_value(str, "STRING", type_name)); - } - } let radix = T::from(10_u8); let stop_value = min_value / radix; - let mut parse_sign_and_digits = true; for &ch in &trimmed_bytes[idx..] { - if parse_sign_and_digits { - if ch == b'.' { - return Err(invalid_value(str, "STRING", type_name)); - } + if ch == b'.' { + return Err(invalid_value(str, "STRING", type_name)); + } - if !ch.is_ascii_digit() { - return Err(invalid_value(str, "STRING", type_name)); - } + if !ch.is_ascii_digit() { + return Err(invalid_value(str, "STRING", type_name)); + } - let digit: T = T::from(ch - b'0'); + let digit: T = T::from(ch - b'0'); - if result < stop_value { - return Err(invalid_value(str, "STRING", type_name)); - } - let v = result * radix; - match v.checked_sub(&digit) { - Some(x) if x <= T::zero() => result = x, - _ => { - return Err(invalid_value(str, "STRING", type_name)); - } - } - } else { - if !ch.is_ascii_digit() { + if result < stop_value { + return Err(invalid_value(str, "STRING", type_name)); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { return Err(invalid_value(str, "STRING", type_name)); } } @@ -2131,60 +2137,42 @@ fn do_parse_string_to_int_try + min_value: T, ) -> SparkResult> { let bytes = str.as_bytes(); - let mut start = 0; - let mut end = bytes.len(); - - while start < end && bytes[start].is_ascii_whitespace() { - start += 1; - } - while end > start && bytes[end - 1].is_ascii_whitespace() { - end -= 1; - } + let (start, end) = trim_whitespace(bytes); if start == end { return Ok(None); } let trimmed_bytes = &bytes[start..end]; - let len = trimmed_bytes.len(); + + let (negative, idx) = match parse_sign(trimmed_bytes) { + Some(result) => result, + None => return Ok(None), + }; + let mut result: T = T::zero(); - let mut idx = 0; - let first_char = trimmed_bytes[0]; - let negative = first_char == b'-'; - if negative || first_char == b'+' { - idx = 1; - if len == 1 { - return Ok(None); - } - } let radix = T::from(10_u8); let stop_value = min_value / radix; - let mut parse_sign_and_digits = true; + // we don't have to go beyond decimal point in try eval mode - early return NULL for &ch in &trimmed_bytes[idx..] { - if parse_sign_and_digits { - if ch == b'.' { - return Ok(None); - } + if ch == b'.' { + return Ok(None); + } - if !ch.is_ascii_digit() { - return Ok(None); - } + if !ch.is_ascii_digit() { + return Ok(None); + } - let digit: T = T::from(ch - b'0'); + let digit: T = T::from(ch - b'0'); - if result < stop_value { - return Ok(None); - } - let v = result * radix; - match v.checked_sub(&digit) { - Some(x) if x <= T::zero() => result = x, - _ => { - return Ok(None); - } - } - } else { - if !ch.is_ascii_digit() { + if result < stop_value { + return Ok(None); + } + let v = result * radix; + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { return Ok(None); } } @@ -2517,15 +2505,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { Ok((final_mantissa, final_scale)) } -/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode -#[inline] -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult> { - match eval_mode { - EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), - _ => Ok(None), - } -} - #[inline] fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { SparkError::CastInvalidValue {