@@ -19,11 +19,14 @@ use crate::utils::array_with_timezone;
1919use crate :: { timezone, BinaryOutputStyle } ;
2020use crate :: { EvalMode , SparkError , SparkResult } ;
2121use arrow:: array:: builder:: StringBuilder ;
22- use arrow:: array:: { ArrayAccessor , BooleanBuilder , Decimal128Builder , DictionaryArray , GenericByteArray , LargeStringArray , ListArray , PrimitiveBuilder , StringArray , StructArray } ;
22+ use arrow:: array:: {
23+ ArrayAccessor , BooleanBuilder , Decimal128Builder , DictionaryArray , GenericByteArray , ListArray ,
24+ PrimitiveBuilder , StringArray , StructArray ,
25+ } ;
2326use arrow:: compute:: can_cast_types;
2427use arrow:: datatypes:: {
25- i256, ArrowDictionaryKeyType , ArrowNativeType , DataType , Decimal256Type , DecimalType ,
26- GenericBinaryType , Schema ,
28+ i256, ArrowDictionaryKeyType , ArrowNativeType , DataType , Decimal256Type , GenericBinaryType ,
29+ Schema ,
2730} ;
2831use arrow:: {
2932 array:: {
@@ -55,7 +58,6 @@ use num::{
5558 ToPrimitive , Zero ,
5659} ;
5760use regex:: Regex ;
58- use std:: num:: ParseFloatError ;
5961use std:: str:: FromStr ;
6062use std:: {
6163 any:: Any ,
@@ -1082,7 +1084,7 @@ fn cast_string_to_decimal128_impl(
10821084) -> SparkResult < ArrayRef > {
10831085 let string_array = array
10841086 . as_any ( )
1085- . downcast_ref :: < LargeStringArray > ( )
1087+ . downcast_ref :: < StringArray > ( )
10861088 . ok_or_else ( || SparkError :: Internal ( "Expected string array" . to_string ( ) ) ) ?;
10871089
10881090 let mut decimal_builder = Decimal128Builder :: with_capacity ( string_array. len ( ) ) ;
@@ -1195,8 +1197,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
11951197 }
11961198
11971199 // Parse the string as a decimal number
1198- // Note: We do NOT strip 'D' or 'F' suffixes - let parsing fail naturally
1199- // This matches Spark's behavior which uses JavaBigDecimal(string)
1200+ // Note: We do NOT strip 'D' or 'F' suffixes - let rust's parsing fail naturally for invalid input
12001201 match parse_decimal_str ( s) {
12011202 Ok ( ( mantissa, exponent) ) => {
12021203 // Convert to target scale
@@ -1246,30 +1247,48 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
12461247 }
12471248}
12481249
1249- /// Parse a decimal string into ( mantissa, scale)
1250+ /// Parse a decimal string into mantissa and scale
12501251/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
12511252fn parse_decimal_str ( s : & str ) -> Result < ( i128 , i32 ) , String > {
12521253 let s = s. trim ( ) ;
12531254 if s. is_empty ( ) {
12541255 return Err ( "Empty string" . to_string ( ) ) ;
12551256 }
12561257
1257- let negative = s. starts_with ( '-' ) ;
1258- let s = if negative || s. starts_with ( '+' ) {
1259- & s[ 1 ..]
1258+ // Check if input is scientific notation (e.g., "1.23E-5", "1e10")
1259+ let ( mantissa_str, exponent) = if let Some ( e_pos) = s. find ( |c| [ 'e' , 'E' ] . contains ( & c) ) {
1260+ let mantissa_part = & s[ ..e_pos] ;
1261+ let exponent_part = & s[ e_pos + 1 ..] ;
1262+
1263+ // Parse exponent part
1264+ let exp: i32 = exponent_part
1265+ . parse ( )
1266+ . map_err ( |_| "Invalid exponent" . to_string ( ) ) ?;
1267+
1268+ ( mantissa_part, exp)
12601269 } else {
1261- s
1270+ ( s , 0 )
12621271 } ;
12631272
1264- // Split by decimal point
1265- let parts: Vec < & str > = s. split ( '.' ) . collect ( ) ;
1273+ let negative = mantissa_str. starts_with ( '-' ) ;
1274+ let mantissa_str = if negative || mantissa_str. starts_with ( '+' ) {
1275+ & mantissa_str[ 1 ..]
1276+ } else {
1277+ mantissa_str
1278+ } ;
1279+
1280+ let split_by_dot: Vec < & str > = mantissa_str. split ( '.' ) . collect ( ) ;
12661281
1267- if parts . len ( ) > 2 {
1282+ if split_by_dot . len ( ) > 2 {
12681283 return Err ( "Multiple decimal points" . to_string ( ) ) ;
12691284 }
12701285
1271- let integral_part = parts[ 0 ] ;
1272- let fractional_part = if parts. len ( ) == 2 { parts[ 1 ] } else { "" } ;
1286+ let integral_part = split_by_dot[ 0 ] ;
1287+ let fractional_part = if split_by_dot. len ( ) == 2 {
1288+ split_by_dot[ 1 ]
1289+ } else {
1290+ ""
1291+ } ;
12731292
12741293 // Parse integral part
12751294 let integral_value: i128 = if integral_part. is_empty ( ) {
@@ -1281,7 +1300,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
12811300 } ;
12821301
12831302 // Parse fractional part
1284- let scale = fractional_part. len ( ) as i32 ;
1303+ let fractional_scale = fractional_part. len ( ) as i32 ;
12851304 let fractional_value: i128 = if fractional_part. is_empty ( ) {
12861305 0
12871306 } else {
@@ -1290,15 +1309,17 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
12901309 . map_err ( |_| "Invalid fractional part" . to_string ( ) ) ?
12911310 } ;
12921311
1293- // Combine: value = integral * 10^scale + fractional
1312+ // Combine: value = integral * 10^fractional_scale + fractional
12941313 let mantissa = integral_value
1295- . checked_mul ( 10_i128 . pow ( scale as u32 ) )
1314+ . checked_mul ( 10_i128 . pow ( fractional_scale as u32 ) )
12961315 . and_then ( |v| v. checked_add ( fractional_value) )
12971316 . ok_or ( "Overflow in mantissa calculation" ) ?;
12981317
12991318 let final_mantissa = if negative { -mantissa } else { mantissa } ;
1300-
1301- Ok ( ( final_mantissa, scale) )
1319+ // final scale = fractional_scale - exponent
1320+ // For example : "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
1321+ let final_scale = fractional_scale - exponent;
1322+ Ok ( ( final_mantissa, final_scale) )
13021323}
13031324
13041325fn cast_string_to_float (
@@ -1307,8 +1328,9 @@ fn cast_string_to_float(
13071328 eval_mode : EvalMode ,
13081329) -> SparkResult < ArrayRef > {
13091330 match to_type {
1310- DataType :: Float16 => cast_string_to_float_impl :: < Float32Type > ( array, eval_mode, "FLOAT" ) ,
1311- DataType :: Float32 => cast_string_to_float_impl :: < Float32Type > ( array, eval_mode, "FLOAT" ) ,
1331+ DataType :: Float16 | DataType :: Float32 => {
1332+ cast_string_to_float_impl :: < Float32Type > ( array, eval_mode, "FLOAT" )
1333+ }
13121334 DataType :: Float64 => cast_string_to_float_impl :: < Float64Type > ( array, eval_mode, "DOUBLE" ) ,
13131335 _ => Err ( SparkError :: Internal ( format ! (
13141336 "Unsupported cast to float type: {:?}" ,
@@ -1323,92 +1345,59 @@ fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
13231345 type_name : & str ,
13241346) -> SparkResult < ArrayRef >
13251347where
1326- T :: Native : FloatParse ,
1348+ T :: Native : FromStr + num :: Float ,
13271349{
13281350 let arr = array
13291351 . as_any ( )
13301352 . downcast_ref :: < StringArray > ( )
1331- . ok_or_else ( || SparkError :: Internal ( "could not parse input as string type " . to_string ( ) ) ) ?;
1353+ . ok_or_else ( || SparkError :: Internal ( "Expected string array " . to_string ( ) ) ) ?;
13321354
1333- let mut cast_array = PrimitiveArray :: < T > :: builder ( arr. len ( ) ) ;
1355+ let mut builder = PrimitiveBuilder :: < T > :: with_capacity ( arr. len ( ) ) ;
13341356
13351357 for i in 0 ..arr. len ( ) {
13361358 if arr. is_null ( i) {
1337- cast_array . append_null ( ) ;
1359+ builder . append_null ( ) ;
13381360 } else {
13391361 let str_value = arr. value ( i) . trim ( ) ;
1340- match T :: Native :: parse_spark_float ( str_value) {
1341- Ok ( v) => {
1342- cast_array. append_value ( v) ;
1343- }
1344- Err ( _) => {
1362+ match parse_string_to_float ( str_value) {
1363+ Some ( v) => builder. append_value ( v) ,
1364+ None => {
13451365 if eval_mode == EvalMode :: Ansi {
13461366 return Err ( invalid_value ( arr. value ( i) , "STRING" , type_name) ) ;
1347- } else {
1348- cast_array. append_null ( ) ;
13491367 }
1368+ builder. append_null ( ) ;
13501369 }
13511370 }
13521371 }
13531372 }
1354- Ok ( Arc :: new ( cast_array. finish ( ) ) )
1355- }
13561373
1357- /// Trait for parsing float from str
1358- trait FloatParse : Sized {
1359- fn parse_spark_float ( s : & str ) -> Result < Self , ParseFloatError > ;
1374+ Ok ( Arc :: new ( builder. finish ( ) ) )
13601375}
13611376
1362- impl FloatParse for f32 {
1363- fn parse_spark_float ( s : & str ) -> Result < Self , ParseFloatError > {
1364- let s_lower = s. to_lowercase ( ) ;
1365-
1366- if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
1367- {
1368- return Ok ( f32:: INFINITY ) ;
1369- }
1370-
1371- if s_lower == "-inf" || s_lower == "-infinity" {
1372- return Ok ( f32:: NEG_INFINITY ) ;
1373- }
1374-
1375- if s_lower == "nan" {
1376- return Ok ( f32:: NAN ) ;
1377- }
1378-
1379- let pruned = if s_lower. ends_with ( 'd' ) || s_lower. ends_with ( 'f' ) {
1380- & s[ ..s. len ( ) - 1 ]
1381- } else {
1382- s
1383- } ;
1384- pruned. parse :: < f32 > ( )
1377+ /// helper to parse floats from string inputs
1378+ fn parse_string_to_float < F > ( s : & str ) -> Option < F >
1379+ where
1380+ F : FromStr + num:: Float ,
1381+ {
1382+ let s_lower = s. to_lowercase ( ) ;
1383+ // Handle +inf / -inf
1384+ if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity" {
1385+ return Some ( F :: infinity ( ) ) ;
13851386 }
1386- }
1387-
1388- impl FloatParse for f64 {
1389- fn parse_spark_float ( s : & str ) -> Result < Self , ParseFloatError > {
1390- let s_lower = s. to_lowercase ( ) ;
1391-
1392- if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
1393- {
1394- return Ok ( f64:: INFINITY ) ;
1395- }
1396-
1397- if s_lower == "-inf" || s_lower == "-infinity" {
1398- return Ok ( f64:: NEG_INFINITY ) ;
1399- }
1400-
1401- if s_lower == "nan" {
1402- return Ok ( f64:: NAN ) ;
1403- }
1404-
1405- let cleaned = if s_lower. ends_with ( 'd' ) || s_lower. ends_with ( 'f' ) {
1406- & s[ ..s. len ( ) - 1 ]
1407- } else {
1408- s
1409- } ;
1410- cleaned. parse :: < f64 > ( )
1387+ if s_lower == "-inf" || s_lower == "-infinity" {
1388+ return Some ( F :: neg_infinity ( ) ) ;
1389+ }
1390+ if s_lower == "nan" {
1391+ return Some ( F :: nan ( ) ) ;
14111392 }
1393+ // Remove D/F suffix if present
1394+ let pruned_float_str = if s_lower. ends_with ( 'd' ) || s_lower. ends_with ( 'f' ) {
1395+ & s[ ..s. len ( ) - 1 ]
1396+ } else {
1397+ s
1398+ } ;
1399+ // Rust's parse logic already handles scientific notations so we just rely on it
1400+ pruned_float_str. parse :: < F > ( ) . ok ( )
14121401}
14131402
14141403fn cast_binary_to_string < O : OffsetSizeTrait > (
0 commit comments