Commit 79d0ea9

support_string_to_non_int_casts

1 parent c807081 commit 79d0ea9

File tree

2 files changed, +120 -96 lines changed


native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 77 additions & 88 deletions
@@ -19,11 +19,14 @@ use crate::utils::array_with_timezone;
 use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
-use arrow::array::{ArrayAccessor, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, LargeStringArray, ListArray, PrimitiveBuilder, StringArray, StructArray};
+use arrow::array::{
+    ArrayAccessor, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+    PrimitiveBuilder, StringArray, StructArray,
+};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
-    i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, DecimalType,
-    GenericBinaryType, Schema,
+    i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
+    Schema,
 };
 use arrow::{
     array::{
@@ -55,7 +58,6 @@ use num::{
     ToPrimitive, Zero,
 };
 use regex::Regex;
-use std::num::ParseFloatError;
 use std::str::FromStr;
 use std::{
     any::Any,
@@ -1082,7 +1084,7 @@ fn cast_string_to_decimal128_impl(
 ) -> SparkResult<ArrayRef> {
     let string_array = array
         .as_any()
-        .downcast_ref::<LargeStringArray>()
+        .downcast_ref::<StringArray>()
         .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;

     let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
@@ -1195,8 +1197,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
     }

     // Parse the string as a decimal number
-    // Note: We do NOT strip 'D' or 'F' suffixes - let parsing fail naturally
-    // This matches Spark's behavior which uses JavaBigDecimal(string)
+    // Note: We do NOT strip 'D' or 'F' suffixes - let Rust's parsing fail naturally for invalid input
     match parse_decimal_str(s) {
         Ok((mantissa, exponent)) => {
             // Convert to target scale
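
A note on that suffix rule: the decimal path simply lets a trailing 'D' or 'F' make parsing fail, while the float path later in this diff strips the suffix before parsing. A minimal sketch of the contrast, using the two helpers this commit introduces (illustrative assertions, not part of the change):

// Illustrative only: "1.5D" is rejected as a decimal but accepted as a float.
assert!(parse_decimal_str("1.5D").is_err()); // "5D" is not a valid fractional part
assert_eq!(parse_string_to_float::<f64>("1.5D"), Some(1.5)); // suffix stripped first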
@@ -1246,30 +1247,48 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
     }
 }

-/// Parse a decimal string into (mantissa, scale)
+/// Parse a decimal string into mantissa and scale
 /// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
 fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     let s = s.trim();
     if s.is_empty() {
         return Err("Empty string".to_string());
     }

-    let negative = s.starts_with('-');
-    let s = if negative || s.starts_with('+') {
-        &s[1..]
+    // Check if input is scientific notation (e.g., "1.23E-5", "1e10")
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+
+        // Parse exponent part
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|_| "Invalid exponent".to_string())?;
+
+        (mantissa_part, exp)
     } else {
-        s
+        (s, 0)
     };

-    // Split by decimal point
-    let parts: Vec<&str> = s.split('.').collect();
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    let split_by_dot: Vec<&str> = mantissa_str.split('.').collect();

-    if parts.len() > 2 {
+    if split_by_dot.len() > 2 {
         return Err("Multiple decimal points".to_string());
     }

-    let integral_part = parts[0];
-    let fractional_part = if parts.len() == 2 { parts[1] } else { "" };
+    let integral_part = split_by_dot[0];
+    let fractional_part = if split_by_dot.len() == 2 {
+        split_by_dot[1]
+    } else {
+        ""
+    };

     // Parse integral part
     let integral_value: i128 = if integral_part.is_empty() {
@@ -1281,7 +1300,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     };

     // Parse fractional part
-    let scale = fractional_part.len() as i32;
+    let fractional_scale = fractional_part.len() as i32;
     let fractional_value: i128 = if fractional_part.is_empty() {
         0
     } else {
@@ -1290,15 +1309,17 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
             .map_err(|_| "Invalid fractional part".to_string())?
     };

-    // Combine: value = integral * 10^scale + fractional
+    // Combine: value = integral * 10^fractional_scale + fractional
     let mantissa = integral_value
-        .checked_mul(10_i128.pow(scale as u32))
+        .checked_mul(10_i128.pow(fractional_scale as u32))
         .and_then(|v| v.checked_add(fractional_value))
         .ok_or("Overflow in mantissa calculation")?;

     let final_mantissa = if negative { -mantissa } else { mantissa };
-
-    Ok((final_mantissa, scale))
+    // Final scale = fractional_scale - exponent
+    // For example, "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
+    let final_scale = fractional_scale - exponent;
+    Ok((final_mantissa, final_scale))
 }

 fn cast_string_to_float(
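
To make the mantissa/scale arithmetic concrete, here is a worked trace of parse_decimal_str on a few inputs (illustrative assertions derived from the code above, not part of the commit):

// "1.23E-5": mantissa_str = "1.23", exponent = -5;
// integral = 1, fractional = "23" (fractional_scale = 2);
// mantissa = 1 * 10^2 + 23 = 123; final scale = 2 - (-5) = 7,
// i.e. 123 * 10^-7 = 1.23e-5.
assert_eq!(parse_decimal_str("1.23E-5"), Ok((123, 7)));
// Plain decimals keep their fractional scale (exponent = 0).
assert_eq!(parse_decimal_str("123.45"), Ok((12345, 2)));
assert_eq!(parse_decimal_str("-0.001"), Ok((-1, 3)));
// A positive exponent yields a negative scale: 1 * 10^5 = 100000.
assert_eq!(parse_decimal_str("1e5"), Ok((1, -5)));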
@@ -1307,8 +1328,9 @@ fn cast_string_to_float(
     eval_mode: EvalMode,
 ) -> SparkResult<ArrayRef> {
     match to_type {
-        DataType::Float16 => cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT"),
-        DataType::Float32 => cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT"),
+        DataType::Float16 | DataType::Float32 => {
+            cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT")
+        }
         DataType::Float64 => cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE"),
         _ => Err(SparkError::Internal(format!(
             "Unsupported cast to float type: {:?}",
@@ -1323,92 +1345,59 @@ fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
     type_name: &str,
 ) -> SparkResult<ArrayRef>
 where
-    T::Native: FloatParse,
+    T::Native: FromStr + num::Float,
 {
     let arr = array
         .as_any()
         .downcast_ref::<StringArray>()
-        .ok_or_else(|| SparkError::Internal("could not parse input as string type".to_string()))?;
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;

-    let mut cast_array = PrimitiveArray::<T>::builder(arr.len());
+    let mut builder = PrimitiveBuilder::<T>::with_capacity(arr.len());

     for i in 0..arr.len() {
         if arr.is_null(i) {
-            cast_array.append_null();
+            builder.append_null();
         } else {
             let str_value = arr.value(i).trim();
-            match T::Native::parse_spark_float(str_value) {
-                Ok(v) => {
-                    cast_array.append_value(v);
-                }
-                Err(_) => {
+            match parse_string_to_float(str_value) {
+                Some(v) => builder.append_value(v),
+                None => {
                     if eval_mode == EvalMode::Ansi {
                         return Err(invalid_value(arr.value(i), "STRING", type_name));
-                    } else {
-                        cast_array.append_null();
                     }
+                    builder.append_null();
                 }
             }
         }
     }
-    Ok(Arc::new(cast_array.finish()))
-}

-/// Trait for parsing float from str
-trait FloatParse: Sized {
-    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError>;
+    Ok(Arc::new(builder.finish()))
 }

-impl FloatParse for f32 {
-    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError> {
-        let s_lower = s.to_lowercase();
-
-        if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
-        {
-            return Ok(f32::INFINITY);
-        }
-
-        if s_lower == "-inf" || s_lower == "-infinity" {
-            return Ok(f32::NEG_INFINITY);
-        }
-
-        if s_lower == "nan" {
-            return Ok(f32::NAN);
-        }
-
-        let pruned = if s_lower.ends_with('d') || s_lower.ends_with('f') {
-            &s[..s.len() - 1]
-        } else {
-            s
-        };
-        pruned.parse::<f32>()
+/// Helper to parse floats from string inputs
+fn parse_string_to_float<F>(s: &str) -> Option<F>
+where
+    F: FromStr + num::Float,
+{
+    let s_lower = s.to_lowercase();
+    // Handle +inf / -inf
+    if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity" {
+        return Some(F::infinity());
     }
-}
-
-impl FloatParse for f64 {
-    fn parse_spark_float(s: &str) -> Result<Self, ParseFloatError> {
-        let s_lower = s.to_lowercase();
-
-        if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity"
-        {
-            return Ok(f64::INFINITY);
-        }
-
-        if s_lower == "-inf" || s_lower == "-infinity" {
-            return Ok(f64::NEG_INFINITY);
-        }
-
-        if s_lower == "nan" {
-            return Ok(f64::NAN);
-        }
-
-        let cleaned = if s_lower.ends_with('d') || s_lower.ends_with('f') {
-            &s[..s.len() - 1]
-        } else {
-            s
-        };
-        cleaned.parse::<f64>()
+    if s_lower == "-inf" || s_lower == "-infinity" {
+        return Some(F::neg_infinity());
+    }
+    if s_lower == "nan" {
+        return Some(F::nan());
     }
+    // Remove D/F suffix if present
+    let pruned_float_str = if s_lower.ends_with('d') || s_lower.ends_with('f') {
+        &s[..s.len() - 1]
+    } else {
+        s
+    };
+    // Rust's parse logic already handles scientific notation, so we just rely on it
+    pruned_float_str.parse::<F>().ok()
 }

 fn cast_binary_to_string<O: OffsetSizeTrait>(
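
A quick sanity sketch of the new generic helper (a hypothetical unit test, not part of the commit; it exercises only behavior visible in the function above):

#[test]
fn parse_string_to_float_examples() {
    // Case-insensitive infinity / NaN spellings map to the float constants.
    assert_eq!(parse_string_to_float::<f32>("Infinity"), Some(f32::INFINITY));
    assert_eq!(parse_string_to_float::<f64>("-inf"), Some(f64::NEG_INFINITY));
    assert!(parse_string_to_float::<f32>("NaN").is_some_and(|v| v.is_nan()));
    // Java-style D/F suffixes are stripped before parsing.
    assert_eq!(parse_string_to_float::<f64>("1.5D"), Some(1.5));
    // Scientific notation is delegated to Rust's own parser.
    assert_eq!(parse_string_to_float::<f32>("1.23e4"), Some(1.23e4));
    // Malformed input yields None: NULL in legacy mode, an error in ANSI mode.
    assert_eq!(parse_string_to_float::<f64>("1.23e"), None);
    assert_eq!(parse_string_to_float::<f64>("e5"), None);
}

The generic function replaces the old FloatParse trait with its duplicated f32/f64 impls; bounding on FromStr + num::Float keeps one copy of the special-case logic for both widths.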

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 43 additions & 8 deletions
@@ -691,7 +691,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

-  ignore("cast StringType to DecimalType(10,2)") {
+  test("cast StringType to DecimalType(10,2) fuzz") {
     // https://github.com/apache/datafusion-comet/issues/325
     val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
     castTest(values, DataTypes.createDecimalType(10, 2))
@@ -713,7 +713,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       "",
       "abc",
       null).toDF("a")
-    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+    Seq(true, false).foreach(k => castTest(values, DataTypes.createDecimalType(10, 2), k))
   }

   test("cast StringType to DecimalType(38,10) high precision") {
@@ -729,18 +729,53 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       "",
       "abc",
       null).toDF("a")
-    castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = false)
+    Seq(true, false).foreach(k =>
+      castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = k))
   }

-  test("cast StringType to DecimalType(10,2) (partial support)") {
-    withSQLConf(
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
-      SQLConf.ANSI_ENABLED.key -> "false") {
+  test("cast StringType to Float type scientific notation") {
+    val values = Seq(
+      "1.23E-5",
+      "1.23e10",
+      "1.23E+10",
+      "-1.23e-5",
+      "1e5",
+      "1E-2",
+      "-1.5e3",
+      "1.23E0",
+      "0e0",
+      "1.23e",
+      "e5",
+      null).toDF("a")
+    Seq(true, false).foreach(k => castTest(values, DataTypes.FloatType, testAnsi = k))
+  }
+
+  test("cast StringType to Decimal type scientific notation") {
+    val values = Seq(
+      "1.23E-5",
+      "1.23e10",
+      "1.23E+10",
+      "-1.23e-5",
+      "1e5",
+      "1E-2",
+      "-1.5e3",
+      "1.23E0",
+      "0e0",
+      "1.23e",
+      "e5",
+      null).toDF("a")
+    Seq(true, false).foreach(k =>
+      castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = k))
+  }
+
+  test("cast StringType to DecimalType(10,2)") {
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
       val values = gen
         .generateStrings(dataSize, "0123456789.", 8)
         .filter(_.exists(_.isDigit))
         .toDF("a")
-      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+      Seq(true, false).foreach(k =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = k))
     }
   }
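
The two malformed strings shared by these suites, "1.23e" (empty exponent) and "e5" (no mantissa digits), exercise invalid input: Spark returns NULL for them in legacy mode and raises an error in ANSI mode. A sketch of the corresponding native-side behavior, using the Rust helpers from cast.rs above (illustrative assertions, not in the change):

// Illustrative, based on the Rust helpers introduced in this commit.
assert_eq!(parse_decimal_str("0e0"), Ok((0, 0)));
assert_eq!(parse_decimal_str("-1.5e3"), Ok((-15, -2))); // -15 * 10^2 = -1500
assert!(parse_decimal_str("1.23e").is_err()); // empty exponent fails to parse
assert_eq!(parse_string_to_float::<f32>("e5"), None); // Rust's float parser rejects it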
