Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions datafusion/catalog/src/information_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{CatalogProviderList, SchemaProvider, TableProvider};
use arrow::array::builder::{BooleanBuilder, UInt8Builder};
use arrow::{
array::{StringBuilder, UInt64Builder},
datatypes::{DataType, Field, Schema, SchemaRef},
datatypes::{DataType, Field, FieldRef, Schema, SchemaRef},
record_batch::RecordBatch,
};
use async_trait::async_trait;
Expand All @@ -34,7 +34,10 @@ use datafusion_common::error::Result;
use datafusion_common::types::NativeType;
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
use datafusion_expr::function::WindowUDFFieldArgs;
use datafusion_expr::{
AggregateUDF, ReturnFieldArgs, ScalarUDF, Signature, TypeSignature, WindowUDF,
};
use datafusion_expr::{TableType, Volatility};
use datafusion_physical_plan::SendableRecordBatchStream;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
Expand Down Expand Up @@ -421,10 +424,24 @@ fn get_udf_args_and_return_types(
Ok(arg_types
.into_iter()
.map(|arg_types| {
// only handle the function which implemented [`ScalarUDFImpl::return_type`] method
let arg_fields: Vec<FieldRef> = arg_types
.iter()
.enumerate()
.map(|(i, t)| {
Arc::new(Field::new(format!("arg_{i}"), t.clone(), true))
})
.collect();
let scalar_arguments = vec![None; arg_fields.len()];
let return_type = udf
.return_type(&arg_types)
.map(|t| remove_native_type_prefix(&NativeType::from(t)))
.return_field_from_args(ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &scalar_arguments,
})
.map(|f| {
remove_native_type_prefix(&NativeType::from(
f.data_type().clone(),
))
})
.ok();
let arg_types = arg_types
.into_iter()
Expand All @@ -447,11 +464,21 @@ fn get_udaf_args_and_return_types(
Ok(arg_types
.into_iter()
.map(|arg_types| {
// only handle the function which implemented [`ScalarUDFImpl::return_type`] method
let arg_fields: Vec<FieldRef> = arg_types
.iter()
.enumerate()
.map(|(i, t)| {
Arc::new(Field::new(format!("arg_{i}"), t.clone(), true))
})
.collect();
let return_type = udaf
.return_type(&arg_types)
.ok()
.map(|t| remove_native_type_prefix(&NativeType::from(t)));
.return_field(&arg_fields)
.map(|f| {
remove_native_type_prefix(&NativeType::from(
f.data_type().clone(),
))
})
.ok();
let arg_types = arg_types
.into_iter()
.map(|t| remove_native_type_prefix(&NativeType::from(t)))
Expand All @@ -473,12 +500,26 @@ fn get_udwf_args_and_return_types(
Ok(arg_types
.into_iter()
.map(|arg_types| {
// only handle the function which implemented [`ScalarUDFImpl::return_type`] method
let arg_fields: Vec<FieldRef> = arg_types
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting that we omitted window functions before 🤔

And somehow this fix doesn't affect any existing tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good morning!

If i'm allowed, this could use a dedicated test/tests, at the same time i wanted to keep It lean for reviewers.

If any follow up Is needed, please consider me for it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this PR is fine to merge as is, but I wouldn't be opposed to adding some SLT tests for that case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, i opened #20090 to track this as it might be a good newcomer task also

Thanks again for your review @Jefffrey

.iter()
.enumerate()
.map(|(i, t)| {
Arc::new(Field::new(format!("arg_{i}"), t.clone(), true))
})
.collect();
let return_type = udwf
.field(WindowUDFFieldArgs::new(&arg_fields, udwf.name()))
.map(|f| {
remove_native_type_prefix(&NativeType::from(
f.data_type().clone(),
))
})
.ok();
let arg_types = arg_types
.into_iter()
.map(|t| remove_native_type_prefix(&NativeType::from(t)))
.collect::<Vec<_>>();
(arg_types, None)
(arg_types, return_type)
})
.collect::<BTreeSet<_>>())
}
Expand Down
11 changes: 7 additions & 4 deletions datafusion/functions/benches/date_trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::datatypes::Field;
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};
use datafusion_functions::datetime::date_trunc;
use rand::Rng;
use rand::rngs::ThreadRng;
Expand Down Expand Up @@ -57,10 +57,13 @@ fn criterion_benchmark(c: &mut Criterion) {
})
.collect::<Vec<_>>();

let return_type = udf
.return_type(&args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>())
let scalar_arguments = vec![None; arg_fields.len()];
let return_field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &arg_fields,
scalar_arguments: &scalar_arguments,
})
.unwrap();
let return_field = Arc::new(Field::new("f", return_type, true));
let config_options = Arc::new(ConfigOptions::default());

b.iter(|| {
Expand Down
26 changes: 10 additions & 16 deletions datafusion/functions/src/datetime/date_trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use arrow::datatypes::{Field, FieldRef};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::types::{NativeType, logical_date, logical_string};
use datafusion_common::{
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err,
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err,
};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Expand Down Expand Up @@ -223,27 +223,21 @@ impl ScalarUDFImpl for DateTruncFunc {
&self.signature
}

// keep return_type implementation for information schema generation
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks this should avoid internal_error, good catch

if arg_types[1].is_null() {
Ok(Timestamp(Nanosecond, None))
} else {
Ok(arg_types[1].clone())
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let data_types = args
.arg_fields
.iter()
.map(|f| f.data_type())
.cloned()
.collect::<Vec<_>>();
let return_type = self.return_type(&data_types)?;
let field = &args.arg_fields[1];
let return_type = if field.data_type().is_null() {
Timestamp(Nanosecond, None)
} else {
field.data_type().clone()
};
Ok(Arc::new(Field::new(
self.name(),
return_type,
args.arg_fields[1].is_nullable(),
field.is_nullable(),
)))
}

Expand Down