diff --git a/rust/arrow/src/compute/kernels/arity.rs b/rust/arrow/src/compute/kernels/arity.rs new file mode 100644 index 00000000000..11139f83270 --- /dev/null +++ b/rust/arrow/src/compute/kernels/arity.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernels suitable to perform operations on primitive arrays. + +use crate::array::{Array, ArrayData, PrimitiveArray}; +use crate::buffer::Buffer; +use crate::datatypes::ArrowPrimitiveType; + +#[inline] +fn into_primitive_array_data( + array: &PrimitiveArray, + buffer: Buffer, +) -> ArrayData { + ArrayData::new( + O::DATA_TYPE, + array.len(), + None, + array.data_ref().null_buffer().cloned(), + 0, + vec![buffer], + vec![], + ) +} + +/// Applies a unary and infallible function to a primitive array. +/// This is the fastest way to perform an operation on a primitive array when +/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls. +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the corresponding type +/// or this function may panic. 
+/// # Example +/// ```rust +/// # use arrow::array::Int32Array; +/// # use arrow::datatypes::Int32Type; +/// # use arrow::compute::kernels::arity::unary; +/// # fn main() { +/// let array = Int32Array::from(vec![Some(5), Some(7), None]); +/// let c = unary::<_, _, Int32Type>(&array, |x| x * 2 + 1); +/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); +/// # } +/// ``` +pub fn unary(array: &PrimitiveArray, op: F) -> PrimitiveArray +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> O::Native, +{ + let values = array.values().iter().map(|v| op(*v)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size because arrays are sized. + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + + let data = into_primitive_array_data::<_, O>(array, buffer); + PrimitiveArray::::from(std::sync::Arc::new(data)) +} diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs index e2e29620cc5..f0354ab91dc 100644 --- a/rust/arrow/src/compute/kernels/cast.rs +++ b/rust/arrow/src/compute/kernels/cast.rs @@ -39,6 +39,7 @@ use std::str; use std::sync::Arc; use crate::compute::kernels::arithmetic::{divide, multiply}; +use crate::compute::kernels::arity::unary; use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{array::*, compute::take}; @@ -569,45 +570,43 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { (Time64(_), Int64) => cast_array_data::(array, to_type.clone()), (Date32(DateUnit::Day), Date64(DateUnit::Millisecond)) => { let date_array = array.as_any().downcast_ref::().unwrap(); - let mut b = Date64Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - b.append_null()?; - } else { - b.append_value(date_array.value(i) as i64 * MILLISECONDS_IN_DAY)?; - } - } - Ok(Arc::new(b.finish()) as ArrayRef) + let values = + unary::<_, _, Date64Type>(date_array, |x| x as i64 * 
MILLISECONDS_IN_DAY); + + Ok(Arc::new(values) as ArrayRef) } (Date64(DateUnit::Millisecond), Date32(DateUnit::Day)) => { let date_array = array.as_any().downcast_ref::().unwrap(); - let mut b = Date32Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - b.append_null()?; - } else { - b.append_value((date_array.value(i) / MILLISECONDS_IN_DAY) as i32)?; - } - } - Ok(Arc::new(b.finish()) as ArrayRef) + let values = unary::<_, _, Date32Type>(date_array, |x| { + (x / MILLISECONDS_IN_DAY) as i32 + }); + + Ok(Arc::new(values) as ArrayRef) } (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => { - let time_array = Time32MillisecondArray::from(array.data()); - let mult = - Time32MillisecondArray::from(vec![MILLISECONDS as i32; array.len()]); - let time32_ms = multiply(&time_array, &mult)?; + let time_array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new(time32_ms) as ArrayRef) + let values = unary::<_, _, Time32MillisecondType>(time_array, |x| { + x * MILLISECONDS as i32 + }); + + Ok(Arc::new(values) as ArrayRef) } (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => { - let time_array = Time32SecondArray::from(array.data()); - let divisor = Time32SecondArray::from(vec![MILLISECONDS as i32; array.len()]); - let time32_s = divide(&time_array, &divisor)?; + let time_array = array + .as_any() + .downcast_ref::() + .unwrap(); + + let values = unary::<_, _, Time32SecondType>(time_array, |x| { + x / (MILLISECONDS as i32) + }); - Ok(Arc::new(time32_s) as ArrayRef) + Ok(Arc::new(values) as ArrayRef) } + //(Time32(TimeUnit::Second), Time64(_)) => {}, (Time32(from_unit), Time64(to_unit)) => { let time_array = Int32Array::from(array.data()); // note: (numeric_cast + SIMD multiply) is faster than (cast & multiply) @@ -632,18 +631,24 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { } } (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => { - let time_array = Time64NanosecondArray::from(array.data()); - let 
mult = Time64NanosecondArray::from(vec![MILLISECONDS; array.len()]); - let time64_ns = multiply(&time_array, &mult)?; + let time_array = array + .as_any() + .downcast_ref::() + .unwrap(); - Ok(Arc::new(time64_ns) as ArrayRef) + let values = + unary::<_, _, Time64NanosecondType>(time_array, |x| x * MILLISECONDS); + Ok(Arc::new(values) as ArrayRef) } (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => { - let time_array = Time64MicrosecondArray::from(array.data()); - let divisor = Time64MicrosecondArray::from(vec![MILLISECONDS; array.len()]); - let time64_us = divide(&time_array, &divisor)?; + let time_array = array + .as_any() + .downcast_ref::() + .unwrap(); - Ok(Arc::new(time64_us) as ArrayRef) + let values = + unary::<_, _, Time64MicrosecondType>(time_array, |x| x / MILLISECONDS); + Ok(Arc::new(values) as ArrayRef) } (Time64(from_unit), Time32(to_unit)) => { let time_array = Int64Array::from(array.data()); @@ -652,33 +657,16 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { let divisor = from_size / to_size; match to_unit { TimeUnit::Second => { - let mut b = Time32SecondBuilder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - b.append_null()?; - } else { - b.append_value( - (time_array.value(i) as i64 / divisor) as i32, - )?; - } - } - - Ok(Arc::new(b.finish()) as ArrayRef) + let values = unary::<_, _, Time32SecondType>(&time_array, |x| { + (x as i64 / divisor) as i32 + }); + Ok(Arc::new(values) as ArrayRef) } TimeUnit::Millisecond => { - // currently can't dedup this builder [ARROW-4164] - let mut b = Time32MillisecondBuilder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - b.append_null()?; - } else { - b.append_value( - (time_array.value(i) as i64 / divisor) as i32, - )?; - } - } - - Ok(Arc::new(b.finish()) as ArrayRef) + let values = unary::<_, _, Time32MillisecondType>(&time_array, |x| { + (x as i64 / divisor) as i32 + }); + Ok(Arc::new(values) as ArrayRef) } _ => 
unreachable!("array type not supported"), } @@ -806,7 +794,7 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { } /// Get the time unit as a multiple of a second -fn time_unit_multiple(unit: &TimeUnit) -> i64 { +const fn time_unit_multiple(unit: &TimeUnit) -> i64 { match unit { TimeUnit::Second => 1, TimeUnit::Millisecond => MILLISECONDS, diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index 5ac0f0b0c2e..49283e430bd 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -19,6 +19,7 @@ pub mod aggregate; pub mod arithmetic; +pub mod arity; pub mod boolean; pub mod cast; pub mod comparison;