diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs index 990cdec213..a09afae6e1 100644 --- a/native/spark-expr/benches/cast_from_string.rs +++ b/native/spark-expr/benches/cast_from_string.rs @@ -23,45 +23,99 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { - let batch = create_utf8_batch(); + let small_int_batch = create_small_int_string_batch(); + let int_batch = create_int_string_batch(); + let decimal_batch = create_decimal_string_batch(); let expr = Arc::new(Column::new("a", 0)); + + for (mode, mode_name) in [ + (EvalMode::Legacy, "legacy"), + (EvalMode::Ansi, "ansi"), + (EvalMode::Try, "try"), + ] { + let spark_cast_options = SparkCastOptions::new(mode, "", false); + let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); + + let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name)); + group.bench_function("i8", |b| { + b.iter(|| cast_to_i8.evaluate(&small_int_batch).unwrap()); + }); + group.bench_function("i16", |b| { + b.iter(|| cast_to_i16.evaluate(&small_int_batch).unwrap()); + }); + group.bench_function("i32", |b| { + b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap()); + }); + group.bench_function("i64", |b| { + b.iter(|| cast_to_i64.evaluate(&int_batch).unwrap()); + }); + group.finish(); + } + + // Benchmark decimal truncation (Legacy mode only) let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false); - let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, 
spark_cast_options.clone()); - let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); - let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); - let mut group = c.benchmark_group("cast_string_to_int"); - group.bench_function("cast_string_to_i8", |b| { - b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals"); + group.bench_function("i32", |b| { + b.iter(|| cast_to_i32.evaluate(&decimal_batch).unwrap()); }); - group.bench_function("cast_string_to_i16", |b| { - b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); - }); - group.bench_function("cast_string_to_i32", |b| { - b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); - }); - group.bench_function("cast_string_to_i64", |b| { - b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + group.bench_function("i64", |b| { + b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap()); }); + group.finish(); } -// Create UTF8 batch with strings representing ints, floats, nulls -fn create_utf8_batch() -> RecordBatch { +/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks) +fn create_small_int_string_batch() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); let mut b = StringBuilder::new(); for i in 0..1000 { if i % 10 == 0 { b.append_null(); - } else if i % 2 == 0 { - b.append_value(format!("{}", rand::random::<f64>())); } else { - b.append_value(format!("{}", rand::random::<i64>())); + b.append_value(format!("{}", rand::random::<i8>())); } } let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} - RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +/// Create batch with valid integer strings 
(works for all eval modes) +fn create_int_string_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(format!("{}", rand::random::<i32>())); + } + } + let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} + +/// Create batch with decimal strings (for Legacy mode decimal truncation) +fn create_decimal_string_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + // Generate integers with decimal portions to test truncation + let int_part: i32 = rand::random(); + let dec_part: u32 = rand::random::<u32>() % 1000; + b.append_value(format!("{}.{}", int_part, dec_part)); + } + } + let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() } fn config() -> Criterion {