diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index dd1b082cf0f37..aa93d797eb7b3 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -19,10 +19,16 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray}; -use arrow::datatypes::DataType::{Boolean, Float16, Float32, Float64}; -use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; +use arrow::datatypes::DataType::{ + Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, + Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; -use datafusion_common::types::NativeType; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{Coercion, TypeSignatureClass}; @@ -59,14 +65,10 @@ impl Default for IsZeroFunc { impl IsZeroFunc { pub fn new() -> Self { - // Accept any numeric type and coerce to float - let float = Coercion::new_implicit( - TypeSignatureClass::Float, - vec![TypeSignatureClass::Numeric], - NativeType::Float64, - ); + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::coercible(vec![float], Volatility::Immutable), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -107,6 +109,45 @@ impl ScalarUDFImpl for IsZeroFunc { ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(v.is_zero())), )), + + ScalarValue::Int8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + + ScalarValue::Decimal32(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal64(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal128(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + _ => { internal_err!( "Unexpected scalar type for iszero: {:?}", @@ -116,6 +157,10 @@ impl ScalarUDFImpl for IsZeroFunc { } } ColumnarValue::Array(array) => match array.data_type() { + Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null( + array.len(), + )))), + Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( array.as_primitive::(), |x| x == 0.0, @@ -128,6 +173,65 @@ impl ScalarUDFImpl for IsZeroFunc { array.as_primitive::(), |x| x.is_zero(), )))), + + Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + + Decimal32(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal64(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal128(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal256(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))) + } + other => { internal_err!("Unexpected data type {other:?} for function iszero") } diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 03f246c28be19..632eafe1e009a 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,13 +17,21 @@ //! Math function: `isnan()`. -use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; -use datafusion_common::types::NativeType; -use datafusion_common::{Result, ScalarValue, exec_err}; -use datafusion_expr::{Coercion, ColumnarValue, ScalarFunctionArgs, TypeSignatureClass}; - use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::DataType::{ + Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, Int8, Int16, + Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -55,14 +63,10 @@ impl Default for IsNanFunc { impl IsNanFunc { pub fn new() -> Self { - // Accept any numeric type and coerce to float - let float = Coercion::new_implicit( - TypeSignatureClass::Float, - vec![TypeSignatureClass::Numeric], - NativeType::Float64, - ); + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::coercible(vec![float], Volatility::Immutable), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } @@ -84,36 +88,123 @@ impl ScalarUDFImpl for IsNanFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // Handle NULL input - if args.args[0].data_type().is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); - } + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let result = match scalar { + ScalarValue::Float64(Some(v)) => Some(v.is_nan()), + ScalarValue::Float32(Some(v)) => Some(v.is_nan()), + ScalarValue::Float16(Some(v)) => Some(v.is_nan()), - let args = ColumnarValue::values_to_arrays(&args.args)?; - - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f64::is_nan, - )) as ArrayRef, - - DataType::Float32 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f32::is_nan, - )) as ArrayRef, - - DataType::Float16 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x.is_nan(), - )) as ArrayRef, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ); + // Non-float numeric inputs are never NaN + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) => Some(false), + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result))) } - }; - Ok(ColumnarValue::Array(arr)) + ColumnarValue::Array(array) => { + // NOTE: BooleanArray::from_unary preserves nulls. + let arr: ArrayRef = match array.data_type() { + Null => Arc::new(BooleanArray::new_null(array.len())) as ArrayRef, + + Float64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f64::is_nan, + )) as ArrayRef, + Float32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f32::is_nan, + )) as ArrayRef, + Float16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_nan(), + )) as ArrayRef, + + // Non-float numeric arrays are never NaN + Decimal32(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal64(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal128(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal256(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + Int8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 71a969c751591..2227466fdf254 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -111,12 +111,44 @@ SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL: ---- false true true NULL +# isnan: non-float numeric inputs are never NaN +query BBBB +SELECT isnan(1::INT), isnan(0::INT), isnan(NULL::INT), isnan(123::BIGINT) +---- +false false NULL false + +query BBBB +SELECT isnan(1::INT UNSIGNED), isnan(0::INT UNSIGNED), isnan(NULL::INT UNSIGNED), isnan(255::TINYINT UNSIGNED) +---- +false false NULL false + +query BBBB +SELECT isnan(1::DECIMAL(10,2)), isnan(0::DECIMAL(10,2)), isnan(NULL::DECIMAL(10,2)), isnan(-1::DECIMAL(10,2)) +---- +false false NULL false + # iszero query BBBB SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL +# iszero: integers / unsigned / decimals +query BBBB +SELECT iszero(1::INT), iszero(0::INT), iszero(NULL::INT), iszero(-1::INT) +---- +false true NULL false + +query BBBB +SELECT iszero(1::INT UNSIGNED), iszero(0::INT UNSIGNED), iszero(NULL::INT UNSIGNED), iszero(255::TINYINT UNSIGNED) +---- +false true NULL false + +query BBBB +SELECT iszero(1::DECIMAL(10,2)), iszero(0::DECIMAL(10,2)), iszero(NULL::DECIMAL(10,2)), iszero(-1::DECIMAL(10,2)) +---- +false true NULL false + # abs: empty argument statement error SELECT abs();