@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
-    StructArray,
+    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+    PrimitiveBuilder, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
-    ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
+    i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
+    Schema,
 };
 use arrow::{
     array::{
@@ -224,9 +225,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
         }
         Decimal128(_, _) => {
             // https://github.com/apache/datafusion-comet/issues/325
-            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
+            // Does not support fullwidth digits and null byte handling.
             options.allow_incompat
         }
         Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
+        (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
         (Int64, Int32)
         | (Int64, Int16)
         | (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
         ),
         DataType::Utf8 if allow_incompat => matches!(
             to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Binary | DataType::Float32 | DataType::Float64
         ),
         DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
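+/// Dispatches a string-to-decimal cast to the Decimal128 or Decimal256 implementation
+/// for the given target type.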
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
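+/// Casts a `StringArray` to `Decimal128`. In ANSI mode an unparsable or overflowing
+/// value raises an error; otherwise it becomes null.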
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            string_array.value(i),
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
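+/// Casts a `StringArray` to `Decimal256`, reusing the i128 parser and widening the
+/// parsed value to `i256`. Error handling mirrors the Decimal128 path.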
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
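+                    // Note: parsing is done in i128, so string values needing more than
+                    // 38 digits overflow during parsing and are treated as invalid input.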
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to decimal following Spark's behavior
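+/// Illustrative examples: with precision 10 and scale 2, "123.456" rounds half-up to
+/// 12346 (i.e. 123.46); empty strings, "inf"/"nan" variants, and values that do not
+/// fit the target precision all yield `Ok(None)`.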
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+    let string_bytes = s.as_bytes();
+    let mut start = 0;
+    let mut end = string_bytes.len();
+
+    // trim whitespaces
+    while start < end && string_bytes[start].is_ascii_whitespace() {
+        start += 1;
+    }
+    while end > start && string_bytes[end - 1].is_ascii_whitespace() {
+        end -= 1;
+    }
+
+    let trimmed = &s[start..end];
+
+    if trimmed.is_empty() {
+        return Ok(None);
+    }
+    // Handle special values (inf, nan, etc.)
+    if trimmed.eq_ignore_ascii_case("inf")
+        || trimmed.eq_ignore_ascii_case("+inf")
+        || trimmed.eq_ignore_ascii_case("infinity")
+        || trimmed.eq_ignore_ascii_case("+infinity")
+        || trimmed.eq_ignore_ascii_case("-inf")
+        || trimmed.eq_ignore_ascii_case("-infinity")
+        || trimmed.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    // validate and parse mantissa and exponent
+    match parse_decimal_str(trimmed) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
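+            // Illustrative example: "123.45" parses to mantissa 12345 with scale 2;
+            // casting to a target scale of 4 gives an adjustment of 2, so the value is
+            // scaled up to 12345 * 10^2 = 1234500.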
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale); return None if the factor does not fit in i128
+                if scale_adjustment > 38 {
+                    return Ok(None);
+                }
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale); dividing an i128 by more than 10^38 always truncates to 0
+                let abs_scale_adjustment = (-scale_adjustment) as u32;
+                if abs_scale_adjustment > 38 {
+                    return Ok(Some(0));
+                }
+
+                let divisor = 10_i128.pow(abs_scale_adjustment);
+                let quotient_opt = mantissa.checked_div(divisor);
+                // Defensive: checked_div only returns None on division by zero or overflow
+                if quotient_opt.is_none() {
+                    return Ok(None);
+                }
+                let quotient = quotient_opt.unwrap();
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away from zero
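+                // e.g. mantissa -1235 with divisor 10: quotient -123, remainder -5,
+                // |remainder| >= 5, so the result rounds away from zero to -124.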
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow while scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
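+/// Exponents shift the scale, e.g. "12.5e2" -> (125, -1) and "1.23E-5" -> (123, 7).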
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+        // Parse exponent
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
+        return Err("Invalid sign format".to_string());
+    }
+
+    let (integral_part, fractional_part) = match mantissa_str.find('.') {
+        Some(dot_pos) => {
+            if mantissa_str[dot_pos + 1..].contains('.') {
+                return Err("Multiple decimal points".to_string());
+            }
+            (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
+        }
+        None => (mantissa_str, ""),
+    };
+
+    if integral_part.is_empty() && fractional_part.is_empty() {
+        return Err("No digits found".to_string());
+    }
+
+    if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid integral part".to_string());
+    }
+
+    if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid fractional part".to_string());
+    }
+
+    // Parse integral part
+    let integral_value: i128 = if integral_part.is_empty() {
+        // Empty integral part is valid (e.g., ".5" or "-.7e9")
+        0
+    } else {
+        integral_part
+            .parse()
+            .map_err(|_| "Invalid integral part".to_string())?
+    };
+
+    // Parse fractional part
+    let fractional_scale = fractional_part.len() as i32;
+    let fractional_value: i128 = if fractional_part.is_empty() {
+        0
+    } else {
+        fractional_part
+            .parse()
+            .map_err(|_| "Invalid fractional part".to_string())?
+    };
+
+    // Combine: value = integral * 10^fractional_scale + fractional
+    let mantissa = integral_value
+        .checked_mul(10_i128.pow(fractional_scale as u32))
+        .and_then(|v| v.checked_add(fractional_value))
+        .ok_or("Overflow in mantissa calculation")?;
+
+    let final_mantissa = if negative { -mantissa } else { mantissa };
+    // final scale = fractional_scale - exponent
+    // For example: "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
+    let final_scale = fractional_scale - exponent;
+    Ok((final_mantissa, final_scale))
+}
+
 /// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
 #[inline]
 fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {