Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/user-guide/latest/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ The following cast operations are generally compatible with Spark except for the
| string | short | |
| string | integer | |
| string | long | |
| string | float | |
| string | double | |
| string | binary | |
| string | date | Only supports years between 262143 BC and 262142 AD |
| binary | string | |
Expand All @@ -181,8 +183,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| float | decimal | There can be rounding differences |
| double | decimal | There can be rounding differences |
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10)
or strings containing null bytes (e.g \\u0000) |
| string | timestamp | Not all valid formats are supported |
Expand Down
103 changes: 85 additions & 18 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use base64::prelude::*;
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
use datafusion::common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
Expand All @@ -66,8 +67,6 @@ use std::{
sync::Arc,
};

use base64::prelude::*;

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1000000;
Expand Down Expand Up @@ -217,12 +216,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
use DataType::*;
match to_type {
Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
Float32 | Float64 => {
// https://github.com/apache/datafusion-comet/issues/326
// Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
// Does not support ANSI mode.
options.allow_incompat
}
Float32 | Float64 => true,
Decimal128(_, _) => {
// https://github.com/apache/datafusion-comet/issues/325
// Does not support fullwidth digits and null byte handling.
Expand Down Expand Up @@ -975,6 +969,7 @@ fn cast_array(
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
}
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
(Utf8, Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode),
(Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
}
Expand Down Expand Up @@ -1046,7 +1041,7 @@ fn cast_array(
}
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
// use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &native_cast_options)?)
Expand All @@ -1063,6 +1058,86 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

/// Dispatches a Spark-compatible string-to-float cast to the concrete
/// Arrow primitive type (`Float32Type` or `Float64Type`).
///
/// Returns an internal error for any other target type, since callers are
/// expected to route only `Float32`/`Float64` casts here.
fn cast_string_to_float(
    array: &ArrayRef,
    to_type: &DataType,
    eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
    if *to_type == DataType::Float32 {
        cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT")
    } else if *to_type == DataType::Float64 {
        cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE")
    } else {
        Err(SparkError::Internal(format!(
            "Unsupported cast to float type: {:?}",
            to_type
        )))
    }
}

/// Casts each element of a UTF-8 string array to the floating-point type `T`,
/// following Spark semantics: surrounding whitespace is trimmed, and values
/// that fail to parse become null (legacy mode) or raise a cast error
/// (ANSI mode). Nulls in the input stay null in the output.
fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
    array: &ArrayRef,
    eval_mode: EvalMode,
    type_name: &str,
) -> SparkResult<ArrayRef>
where
    T::Native: FromStr + num::Float,
{
    let string_array = array
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;

    let mut output = PrimitiveBuilder::<T>::with_capacity(string_array.len());

    for maybe_value in string_array.iter() {
        match maybe_value {
            None => output.append_null(),
            Some(raw) => match parse_string_to_float(raw.trim()) {
                Some(parsed) => output.append_value(parsed),
                None if eval_mode == EvalMode::Ansi => {
                    // ANSI mode reports the original (untrimmed) input value.
                    return Err(invalid_value(raw, "STRING", type_name));
                }
                None => output.append_null(),
            },
        }
    }

    Ok(Arc::new(output.finish()))
}

/// Parses a trimmed string into a floating-point value following Spark's
/// cast semantics, returning `None` when the input is not a valid number.
///
/// Mirrors the grammar Spark inherits from Java's `Float.parseFloat` /
/// `Double.parseDouble`: the special values `inf`/`infinity`/`nan`
/// (case-insensitive, optionally signed) are accepted on their own, and a
/// trailing `d`/`D`/`f`/`F` type suffix is accepted only after a numeric
/// literal. In particular, inputs such as `"nand"` or `"infF"` are rejected
/// rather than parsed as the special value with the suffix stripped.
fn parse_string_to_float<F>(s: &str) -> Option<F>
where
    F: FromStr + num::Float,
{
    // Handle +inf / -inf / infinity (case-insensitive, optional sign)
    if s.eq_ignore_ascii_case("inf")
        || s.eq_ignore_ascii_case("+inf")
        || s.eq_ignore_ascii_case("infinity")
        || s.eq_ignore_ascii_case("+infinity")
    {
        return Some(F::infinity());
    }
    if s.eq_ignore_ascii_case("-inf") || s.eq_ignore_ascii_case("-infinity") {
        return Some(F::neg_infinity());
    }
    if s.eq_ignore_ascii_case("nan") {
        return Some(F::nan());
    }
    // Remove a trailing D/F type suffix if present (e.g. "1.5f", "2.0D")
    let pruned_float_str =
        if s.ends_with('d') || s.ends_with('D') || s.ends_with('f') || s.ends_with('F') {
            &s[..s.len() - 1]
        } else {
            s
        };
    // Rust's float parser itself accepts the keywords "inf", "infinity" and
    // "nan", which would make inputs like "infd" or "nanF" parse successfully
    // after the suffix is stripped. Spark rejects a type suffix on the
    // special values, so reject any remainder that still contains an
    // alphabetic character other than the exponent marker 'e'/'E'. Genuine
    // special-value inputs were already handled above, before stripping.
    if pruned_float_str
        .chars()
        .any(|c| c.is_ascii_alphabetic() && c != 'e' && c != 'E')
    {
        return None;
    }
    // Rust's parse logic already handles scientific notation, so rely on it
    pruned_float_str.parse::<F>().ok()
}

fn cast_binary_to_string<O: OffsetSizeTrait>(
array: &dyn Array,
spark_cast_options: &SparkCastOptions,
Expand Down Expand Up @@ -1133,11 +1208,7 @@ fn cast_binary_formatter(value: &[u8]) -> String {

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(
from_type: &DataType,
to_type: &DataType,
allow_incompat: bool,
) -> bool {
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
if from_type == to_type {
return true;
}
Expand Down Expand Up @@ -1190,10 +1261,6 @@ fn is_datafusion_spark_compatible(
| DataType::Decimal256(_, _)
| DataType::Utf8 // note that there can be formatting differences
),
DataType::Utf8 if allow_incompat => matches!(
to_type,
DataType::Binary | DataType::Float32 | DataType::Float64
),
DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
DataType::Timestamp(_, _) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
// https://github.com/apache/datafusion-comet/issues/326
Incompatible(
Some(
"Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
"Does not support ANSI mode."))
Compatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10)
Expand Down
58 changes: 37 additions & 21 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -642,34 +642,50 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType)
}

ignore("cast StringType to FloatType") {
test("cast StringType to DoubleType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
}

test("cast StringType to FloatType") {
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
}

test("cast StringType to FloatType (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
castTest(
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
DataTypes.FloatType,
testAnsi = false)
val specialValues: Seq[String] = Seq(
"1.5f",
"1.5F",
"2.0d",
"2.0D",
"3.14159265358979d",
"inf",
"Inf",
"INF",
"+inf",
"+Infinity",
"-inf",
"-Infinity",
"NaN",
"nan",
"NAN",
"1.23e4",
"1.23E4",
"-1.23e-4",
" 123.456789 ",
"0.0",
"-0.0",
"",
"xyz",
null)

test("cast StringType to FloatType special values") {
Seq(true, false).foreach { ansiMode =>
castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = ansiMode)
}
}

ignore("cast StringType to DoubleType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
}
test("cast StringType to DoubleType (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
castTest(
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
DataTypes.DoubleType,
testAnsi = false)
test("cast StringType to DoubleType special values") {
Seq(true, false).foreach { ansiMode =>
castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = ansiMode)
}
}

Expand Down
Loading