diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d286cc75..7bd2316a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -591,12 +591,14 @@ add_library(duckdb_java SHARED src/jni/bindings_common.cpp src/jni/bindings_data_chunk.cpp src/jni/bindings_logical_type.cpp + src/jni/bindings_scalar_function.cpp src/jni/bindings_validity.cpp src/jni/bindings_vector.cpp src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp src/jni/refs.cpp + src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index 3d6f041d6..751076c19 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -109,12 +109,14 @@ add_library(duckdb_java SHARED src/jni/bindings_common.cpp src/jni/bindings_data_chunk.cpp src/jni/bindings_logical_type.cpp + src/jni/bindings_scalar_function.cpp src/jni/bindings_validity.cpp src/jni/bindings_vector.cpp src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp src/jni/refs.cpp + src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/README.md b/README.md index 4042a68db..3f1b60c52 100644 --- a/README.md +++ b/README.md @@ -20,3 +20,5 @@ This optionally takes an argument to only run a single test, for example: ``` java -cp "build/release/duckdb_jdbc_tests.jar:build/release/duckdb_jdbc.jar" org/duckdb/TestDuckDBJDBC test_valid_but_local_config_throws_exception ``` + +Scalar function usage examples: [UDF.MD](UDF.MD) diff --git a/UDF.MD b/UDF.MD new file mode 100644 index 000000000..f8407c0e2 --- /dev/null +++ b/UDF.MD @@ -0,0 +1,202 @@ +# Java Scalar Functions (UDF) + +Use `DuckDBFunctions.scalarFunction()` to build and register scalar functions. +`register(java.sql.Connection)` returns a `DuckDBRegisteredFunction` with metadata about the registered scalar function. + +Registered functions are also tracked in a Java-side registry exposed by `DuckDBDriver.registeredFunctions()`. +This registry is bookkeeping for functions registered through the JDBC API, not an authoritative view of the DuckDB catalog. +`DuckDBDriver.clearFunctionsRegistry()` clears only the Java-side registry and does not de-register functions from DuckDB. + +## Recommended API (Functional Interfaces) + +Use these overloads for simple functions: + +- `withFunction(Supplier)` for zero arguments +- `withFunction(Function)` for one argument +- `withFunction(BiFunction)` for two arguments +- `withIntFunction(IntUnaryOperator | IntBinaryOperator)` for `INTEGER` unary/binary functions +- `withLongFunction(LongUnaryOperator | LongBinaryOperator)` for `BIGINT` unary/binary functions +- `withDoubleFunction(DoubleUnaryOperator | DoubleBinaryOperator)` for `DOUBLE` unary/binary functions + +### Simple example (`withIntFunction`) + +```java +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_one") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1) + .register(conn); +} +``` + +```sql +SELECT java_add_one(41); +``` + +### Slightly more complex example (`withDoubleFunction`) + +```java +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { + DuckDBFunctions.scalarFunction() + .withName("java_weighted_sum") + .withParameter(Double.class) + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction((x, w) -> x * w + 10.0) + .register(conn); +} +``` + +```sql +SELECT java_weighted_sum(2.5, 4.0); +``` + +Behavior: + +- `Function` and `BiFunction` callbacks receive `null` for NULL inputs; implement null handling in Java callback logic. +- `withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)` run with null propagation enabled. +- For `withVectorizedFunction(...)`, use `ctx.propagateNulls(true)` when you want stream-level null skipping and automatic NULL output. +- For `Supplier`, returning `null` writes NULL output. +- `Function` and `BiFunction` are fixed arity only (no varargs). + +Runtime error model: + +- Callback-time reader/writer/context type and value failures throw `DuckDBFunctionException`. +- Invalid row/column indexes throw `IndexOutOfBoundsException`. +- `SQLException` remains for registration-time API usage and type declaration/validation. + +## Type declaration and mapping + +`withParameter(...)` and `withReturnType(...)` accept: + +- `Class` +- `DuckDBColumnType` +- `DuckDBLogicalType` + +Common class mappings include: + +- `Integer` -> `INTEGER` +- `Long` -> `BIGINT` +- `String` -> `VARCHAR` +- `BigDecimal` -> `DECIMAL` +- `BigInteger` -> `HUGEINT` +- `LocalDate` and `java.sql.Date` -> `DATE` +- `LocalDateTime`, `java.sql.Timestamp`, and `java.util.Date` -> `TIMESTAMP` + +Notes: + +- `UHUGEINT` is supported through explicit `DuckDBColumnType.UHUGEINT`/`DuckDBLogicalType` declarations. +- Java class auto-mapping for `BigInteger` remains `HUGEINT`. + +Use `DuckDBLogicalType.decimal(width, scale)` for explicit DECIMAL precision/scale. + +## Varargs + +Declare varargs type with `withVarArgs(DuckDBLogicalType)`. + +For functional varargs, use `withVarArgsFunction(Function)`: + +```java +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs") + .withParameter(Integer.class) // fixed argument(s) + .withVarArgs(intType) // variadic argument type + .withReturnType(Integer.class) + .withVarArgsFunction(args -> { + int sum = 0; + for (Object arg : args) { + sum += (Integer) arg; + } + return sum; + }) + .register(conn); +} +``` + +```sql +SELECT java_sum_varargs(1, 2, 3, 4); +``` + +Notes: + +- `withFunction(Function)` and `withFunction(BiFunction)` reject varargs. +- `withVarArgsFunction(...)` requires `withVarArgs(...)` first. + +## Builder methods + +- `withName(String)` +- `withParameter(Class | DuckDBColumnType | DuckDBLogicalType)` +- `withParameters(Class...)` +- `withReturnType(Class | DuckDBColumnType | DuckDBLogicalType)` +- `withFunction(Supplier | Function | BiFunction)` +- `withIntFunction(IntUnaryOperator | IntBinaryOperator)` +- `withLongFunction(LongUnaryOperator | LongBinaryOperator)` +- `withDoubleFunction(DoubleUnaryOperator | DoubleBinaryOperator)` +- `withVarArgs(DuckDBLogicalType)` +- `withVarArgsFunction(Function)` +- `withVectorizedFunction(DuckDBScalarFunction)` +- `withVolatile()` +- `withSpecialHandling()` +- `register(java.sql.Connection)` + +## Registered Function Metadata And Registry + +`DuckDBRegisteredFunction` exposes immutable metadata about the successful registration result: + +- `name()` +- `functionKind()` +- `isScalar()` +- parameter and return type metadata +- callback and flags used at registration time + +To inspect Java-side registrations: + +```java +List functions = DuckDBDriver.registeredFunctions(); +``` + +The returned list is read-only. Duplicate function names may appear in the registry. + +## Advanced API (`DuckDBScalarFunction`) + +Use `withVectorizedFunction(...)` for full context control through `DuckDBScalarContext`. + +Example with multiple input types (`TIMESTAMP`, `VARCHAR`, `DOUBLE`) and `VARCHAR` output: + +```java +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); + DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); + DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); + DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { + DuckDBFunctions.scalarFunction() + .withName("java_event_label") + .withParameter(tsType) + .withParameter(strType) + .withParameter(dblType) + .withReturnType(strType) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { + String value = row.getLocalDateTime(0) + " | " + + row.getString(1).trim().toUpperCase() + + " | " + row.getDouble(2); + row.setString(value); + }); + }) + .register(conn); +} +``` + +```sql +SELECT java_event_label(TIMESTAMP '2026-04-04 12:00:00', 'launch', 4.5); +``` + +Lifecycle rules: + +- `DuckDBScalarContext`, `DuckDBScalarRow`, `DuckDBReadableVector`, and `DuckDBWritableVector` are valid only during callback execution. +- `DuckDBReadableVector` and `DuckDBWritableVector` are abstract callback runtime types (not interfaces). +- Write exactly one output value per input row for each callback invocation. +- With `propagateNulls(true)`, `DuckDBScalarContext.stream()` skips rows that contain NULL in any input column and writes NULL to the output for those rows. diff --git a/duckdb_java.def b/duckdb_java.def index 68ff3031b..7f340311b 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -52,7 +52,18 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size +Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale @@ -74,6 +85,8 @@ Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1data Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1validity Java_org_duckdb_DuckDBBindings_duckdb_1vector_1ensure_1validity_1writable Java_org_duckdb_DuckDBBindings_duckdb_1vector_1assign_1string_1element_1len +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J Java_org_duckdb_DuckDBBindings_duckdb_1validity_1row_1is_1valid Java_org_duckdb_DuckDBBindings_duckdb_1validity_1set_1row_1validity Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1get_1child @@ -98,6 +111,7 @@ Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function duckdb_adbc_init duckdb_add_aggregate_function_to_set @@ -306,7 +320,6 @@ duckdb_destroy_selection_vector duckdb_destroy_table_function duckdb_destroy_task_state duckdb_destroy_value -duckdb_destroy_vector duckdb_disconnect duckdb_double_to_decimal duckdb_double_to_hugeint diff --git a/duckdb_java.exp b/duckdb_java.exp index 6b6cb687d..b0611b6f6 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -49,7 +49,21 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup _Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error +_Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string +_Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale diff --git a/duckdb_java.map b/duckdb_java.map index 7ed2d7233..0fc9c6fc7 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -51,7 +51,20 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup; Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error; + Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string; + Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale; @@ -97,6 +110,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type; Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function; duckdb_adbc_init; duckdb_add_aggregate_function_to_set; diff --git a/src/jni/bindings_logical_type.cpp b/src/jni/bindings_logical_type.cpp index 63bc64cc0..302bcb160 100644 --- a/src/jni/bindings_logical_type.cpp +++ b/src/jni/bindings_logical_type.cpp @@ -36,6 +36,26 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical return make_ptr_buf(env, lt); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_create_decimal_type + * Signature: (II)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type(JNIEnv *env, jclass, jint width, + jint scale) { + + if (width < 1 || width > 38) { + env->ThrowNew(J_SQLException, "Invalid decimal width: expected 1..38"); + return nullptr; + } + if (scale < 0 || scale > width) { + env->ThrowNew(J_SQLException, "Invalid decimal scale: expected 0..width"); + return nullptr; + } + duckdb_logical_type lt = duckdb_create_decimal_type(static_cast(width), static_cast(scale)); + return make_ptr_buf(env, lt); +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_get_type_id diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp new file mode 100644 index 000000000..5fb9c0b7a --- /dev/null +++ b/src/jni/bindings_scalar_function.cpp @@ -0,0 +1,169 @@ +#include "bindings.hpp" +#include "functions.hpp" +#include "holders.hpp" +#include "refs.hpp" +#include "scalar_functions.hpp" +#include "util.hpp" + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + duckdb_scalar_function scalar_function = + reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + + return scalar_function; +} + +static duckdb_function_info function_info_buf_to_function_info(JNIEnv *env, jobject function_info_buf) { + if (function_info_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function info buffer"); + return nullptr; + } + + auto function_info = reinterpret_cast(env->GetDirectBufferAddress(function_info_buf)); + if (function_info == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function info"); + return nullptr; + } + return function_info; +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function(JNIEnv *env, jclass) { + return make_ptr_buf(env, duckdb_create_scalar_function()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function(JNIEnv *env, jclass, + jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_destroy_scalar_function(&function); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name(JNIEnv *env, jclass, + jobject scalar_function, + jbyteArray name) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + if (name == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function name"); + return; + } + auto function_name = jbyteArray_to_string(env, name); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_name(function, function_name.c_str()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter(JNIEnv *env, jclass, + jobject scalar_function, + jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_add_parameter(function, type); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type( + JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_return_type(function, type); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs(JNIEnv *env, jclass, + jobject scalar_function, + jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_varargs(function, type); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile(JNIEnv *env, jclass, + jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_volatile(function); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling( + JNIEnv *env, jclass, jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_special_handling(function); +} + +JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function(JNIEnv *env, jclass, + jobject connection, + jobject scalar_function) { + auto conn = conn_ref_buf_to_conn(env, connection); + if (env->ExceptionCheck()) { + return static_cast(DuckDBError); + } + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return static_cast(DuckDBError); + } + return static_cast(duckdb_register_scalar_function(conn, function)); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function( + JNIEnv *env, jclass, jobject scalar_function_buf, jobject function_j) { + try { + scalar_function_set_function(env, scalar_function_buf, function_j); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error(JNIEnv *env, jclass, + jobject function_info_buf, + jbyteArray error) { + auto function_info = function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return; + } + if (error == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function error"); + return; + } + auto error_message = jbyteArray_to_string(env, error); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_error(function_info, error_message.c_str()); +} diff --git a/src/jni/bindings_vector.cpp b/src/jni/bindings_vector.cpp index 56876d68e..f8f37bf7c 100644 --- a/src/jni/bindings_vector.cpp +++ b/src/jni/bindings_vector.cpp @@ -21,6 +21,40 @@ static duckdb_vector vector_buf_to_vector(JNIEnv *env, jobject vector_buf) { return vector; } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_vector_get_string + * Signature: (Ljava/nio/ByteBuffer;J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string(JNIEnv *env, jclass, + jobject vector_data, + jlong row) { + + if (vector_data == nullptr) { + env->ThrowNew(J_SQLException, "Invalid vector data buffer"); + return nullptr; + } + auto data = reinterpret_cast(env->GetDirectBufferAddress(vector_data)); + if (data == nullptr) { + env->ThrowNew(J_SQLException, "Invalid vector data"); + return nullptr; + } + idx_t row_idx = jlong_to_idx(env, row); + if (env->ExceptionCheck()) { + return nullptr; + } + auto &string_value = data[row_idx]; + auto string_len = duckdb_string_t_length(string_value); + auto string_ptr = duckdb_string_t_data(&string_value); + return make_jbyteArray(env, string_ptr, string_len); +} + +extern "C" JNIEXPORT jbyteArray JNICALL +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J(JNIEnv *env, jclass clazz, + jobject vector_data, jlong row) { + return Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string(env, clazz, vector_data, row); +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_create_vector diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 436dda5c4..f8a5b1964 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -24,6 +24,7 @@ extern "C" { #include #include +#include using namespace duckdb; using namespace std; diff --git a/src/jni/scalar_functions.cpp b/src/jni/scalar_functions.cpp new file mode 100644 index 000000000..054561ac1 --- /dev/null +++ b/src/jni/scalar_functions.cpp @@ -0,0 +1,271 @@ +extern "C" { +#include "duckdb.h" +} + +#include "holders.hpp" +#include "refs.hpp" +#include "scalar_functions.hpp" +#include "util.hpp" + +#include +#include + +class ScalarFunctionException : public std::exception { +public: + explicit ScalarFunctionException(std::string message_p) : message(std::move(message_p)) { + } + + const char *what() const noexcept override { + return message.c_str(); + } + +private: + std::string message; +}; + +struct JNIEnvGuard { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; + + explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { + if (!vm) { + throw ScalarFunctionException("JVM is not available"); + } + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (get_env_status == JNI_OK) { + return; + } + if (get_env_status != JNI_EDETACHED) { + throw ScalarFunctionException("Failed to get JNI environment"); + } + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw ScalarFunctionException("Failed to attach current thread to JVM"); + } + detach_when_done = true; + } + + ~JNIEnvGuard() { + if (detach_when_done && vm) { + vm->DetachCurrentThread(); + } + } +}; + +struct JavaScalarFunctionState { + JavaVM *vm; + jobject callback; + jmethodID apply_method; + + JavaScalarFunctionState(JavaVM *vm_p, jobject callback_p, jmethodID apply_method_p) + : vm(vm_p), callback(callback_p), apply_method(apply_method_p) { + } + + ~JavaScalarFunctionState() { + if (!vm || !callback) { + return; + } + try { + JNIEnvGuard env_guard(vm); + env_guard.env->DeleteGlobalRef(callback); + } catch (...) { + // noop in destructor + } + } +}; + +struct JavaScalarFunctionLocalState { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; +}; + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + auto scalar_function = reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + return scalar_function; +} + +static std::string consume_java_exception_message(JNIEnv *env) { + auto throwable = env->ExceptionOccurred(); + if (!throwable) { + return "Java exception"; + } + env->ExceptionClear(); + + std::string message = "Java exception"; + auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); + if (!env->ExceptionCheck() && msg) { + message = jstring_to_string(env, msg); + } + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + + env->DeleteLocalRef(throwable); + if (msg) { + env->DeleteLocalRef(msg); + } + + return message; +} + +static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { + if (!vm) { + throw ScalarFunctionException("JVM is not available"); + } + + detach_when_done = false; + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (get_env_status == JNI_OK) { + return; + } + if (get_env_status != JNI_EDETACHED) { + throw ScalarFunctionException("Failed to get JNI environment"); + } + + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw ScalarFunctionException("Failed to attach current thread to JVM"); + } + detach_when_done = true; +} + +static void execute_java_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, duckdb_function_info info, + duckdb_data_chunk input, duckdb_vector output) { + jobject function_info_buf = make_ptr_buf(env, info); + jobject input_chunk_buf = make_ptr_buf(env, input); + jobject output_vector_buf = make_ptr_buf(env, output); + env->CallVoidMethod(state.callback, state.apply_method, function_info_buf, input_chunk_buf, output_vector_buf); + if (function_info_buf) { + env->DeleteLocalRef(function_info_buf); + } + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + + if (env->ExceptionCheck()) { + throw ScalarFunctionException("Java scalar function wrapper threw exception: " + + consume_java_exception_message(env)); + } +} + +static void destroy_java_scalar_function_state(void *extra_info); +static void init_java_scalar_function_capi(duckdb_init_info info); +static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); + +static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, const char *signature, + const char *method_name, const char *error_message) { + auto callback_class = env->GetObjectClass(function_j); + auto apply_method = env->GetMethodID(callback_class, method_name, signature); + env->DeleteLocalRef(callback_class); + if (!apply_method || env->ExceptionCheck()) { + consume_java_exception_message(env); + throw ScalarFunctionException(error_message); + } + return apply_method; +} + +void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j) { + auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); + if (env->ExceptionCheck()) { + return; + } + if (!function_j) { + throw ScalarFunctionException("Invalid scalar function callback"); + } + + JavaVM *vm = nullptr; + if (env->GetJavaVM(&vm) != JNI_OK || !vm) { + throw ScalarFunctionException("Failed to get JVM reference"); + } + + auto callback_ref = env->NewGlobalRef(function_j); + if (!callback_ref) { + throw ScalarFunctionException("Could not create global reference for scalar function callback"); + } + + try { + auto apply_method = get_scalar_callback_method( + env, function_j, "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V", "execute", + "Could not find execute(ByteBuffer, ByteBuffer, ByteBuffer) on scalar function callback"); + auto state = new JavaScalarFunctionState(vm, callback_ref, apply_method); + duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); + duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); + duckdb_scalar_function_set_init(scalar_function, init_java_scalar_function_capi); + } catch (...) { + env->DeleteGlobalRef(callback_ref); + throw; + } +} + +static void destroy_java_scalar_function_state(void *extra_info) { + if (!extra_info) { + return; + } + delete reinterpret_cast(extra_info); +} + +static void destroy_java_scalar_function_local_state(void *state_ptr) { + if (!state_ptr) { + return; + } + + auto state = reinterpret_cast(state_ptr); + if (state->detach_when_done && state->vm) { + state->vm->DetachCurrentThread(); + } + delete state; +} + +static void init_java_scalar_function_capi(duckdb_init_info info) { + JavaScalarFunctionLocalState *local_state = nullptr; + try { + auto state = reinterpret_cast(duckdb_scalar_function_init_get_extra_info(info)); + if (!state) { + duckdb_scalar_function_init_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + local_state = new JavaScalarFunctionLocalState(); + local_state->vm = state->vm; + local_state->env = nullptr; + local_state->detach_when_done = false; + get_or_attach_jni_env(local_state->vm, local_state->env, local_state->detach_when_done); + duckdb_scalar_function_init_set_state(info, local_state, destroy_java_scalar_function_local_state); + local_state = nullptr; + } catch (const std::exception &e) { + if (local_state) { + destroy_java_scalar_function_local_state(local_state); + } + duckdb_scalar_function_init_set_error(info, e.what()); + } +} + +static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, + duckdb_vector output) { + auto state = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); + auto local_state = reinterpret_cast(duckdb_scalar_function_get_state(info)); + if (!state || !local_state || !local_state->env || !input || !output) { + duckdb_scalar_function_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + try { + execute_java_scalar_function(local_state->env, *state, info, input, output); + } catch (const std::exception &e) { + duckdb_scalar_function_set_error(info, e.what()); + } +} diff --git a/src/jni/scalar_functions.hpp b/src/jni/scalar_functions.hpp new file mode 100644 index 000000000..966213bb0 --- /dev/null +++ b/src/jni/scalar_functions.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "bindings.hpp" + +void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j); diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 4ee45c04d..eb535881b 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -17,10 +17,36 @@ public class DuckDBBindings { static native long duckdb_vector_size(); + // scalar function + + static native ByteBuffer duckdb_create_scalar_function(); + + static native void duckdb_destroy_scalar_function(ByteBuffer scalarFunction); + + static native void duckdb_scalar_function_set_name(ByteBuffer scalarFunction, byte[] name); + + static native void duckdb_scalar_function_add_parameter(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native void duckdb_scalar_function_set_return_type(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native void duckdb_scalar_function_set_varargs(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native void duckdb_scalar_function_set_volatile(ByteBuffer scalarFunction); + + static native void duckdb_scalar_function_set_special_handling(ByteBuffer scalarFunction); + + static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalarFunction); + + static native void duckdb_scalar_function_set_function(ByteBuffer scalarFunction, Object function); + + static native void duckdb_scalar_function_set_error(ByteBuffer functionInfo, byte[] error); + // logical type static native ByteBuffer duckdb_create_logical_type(int duckdb_type); + static native ByteBuffer duckdb_create_decimal_type(int width, int scale); + static native int duckdb_get_type_id(ByteBuffer logical_type); static native int duckdb_decimal_width(ByteBuffer logical_type); @@ -65,6 +91,8 @@ public class DuckDBBindings { static native void duckdb_vector_assign_string_element_len(ByteBuffer vector, long index, byte[] str); + static native byte[] duckdb_vector_get_string(ByteBuffer vectorData, long row); + static native ByteBuffer duckdb_list_vector_get_child(ByteBuffer vector); static native long duckdb_list_vector_get_size(ByteBuffer vector); diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java new file mode 100644 index 000000000..23fc16190 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -0,0 +1,49 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; + +/** + * Reader over callback input data chunks. + * + *

Column index violations throw {@link IndexOutOfBoundsException}. + */ +public final class DuckDBDataChunkReader { + private final ByteBuffer chunkRef; + private final long rowCount; + private final long columnCount; + private final DuckDBReadableVector[] vectors; + + DuckDBDataChunkReader(ByteBuffer chunkRef) { + if (chunkRef == null) { + throw new DuckDBFunctionException("Invalid data chunk reference"); + } + this.chunkRef = chunkRef; + this.rowCount = duckdb_data_chunk_get_size(chunkRef); + this.columnCount = duckdb_data_chunk_get_column_count(chunkRef); + this.vectors = new DuckDBReadableVector[Math.toIntExact(columnCount)]; + } + + public long rowCount() { + return rowCount; + } + + public long columnCount() { + return columnCount; + } + + public DuckDBReadableVector vector(long columnIndex) { + if (columnIndex < 0 || columnIndex >= columnCount) { + throw new IndexOutOfBoundsException("Column index out of bounds: " + columnIndex); + } + int arrayIndex = Math.toIntExact(columnIndex); + DuckDBReadableVector vector = vectors[arrayIndex]; + if (vector == null) { + ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); + vector = new DuckDBReadableVectorImpl(vectorRef, rowCount); + vectors[arrayIndex] = vector; + } + return vector; + } +} diff --git a/src/main/java/org/duckdb/DuckDBDriver.java b/src/main/java/org/duckdb/DuckDBDriver.java index e045d14f7..9e9c4cffa 100644 --- a/src/main/java/org/duckdb/DuckDBDriver.java +++ b/src/main/java/org/duckdb/DuckDBDriver.java @@ -42,6 +42,9 @@ public class DuckDBDriver implements java.sql.Driver { private static boolean pinnedDbRefsShutdownHookRegistered = false; private static boolean pinnedDbRefsShutdownHookRun = false; + private static final ArrayList functionsRegistry = new ArrayList<>(); + private static final ReentrantLock functionsRegistryLock = new ReentrantLock(); + private static final Set supportedOptions = new LinkedHashSet<>(); private static final ReentrantLock supportedOptionsLock = new ReentrantLock(); @@ -263,6 +266,33 @@ public static boolean shutdownQueryCancelScheduler() { return true; } + public static List registeredFunctions() { + functionsRegistryLock.lock(); + try { + return Collections.unmodifiableList(new ArrayList<>(functionsRegistry)); + } finally { + functionsRegistryLock.unlock(); + } + } + + public static void clearFunctionsRegistry() { + functionsRegistryLock.lock(); + try { + functionsRegistry.clear(); + } finally { + functionsRegistryLock.unlock(); + } + } + + static void registerFunction(DuckDBRegisteredFunction function) { + functionsRegistryLock.lock(); + try { + functionsRegistry.add(function); + } finally { + functionsRegistryLock.unlock(); + } + } + private static DriverPropertyInfo createDriverPropInfo(String name, String value, String description) { DriverPropertyInfo dpi = new DriverPropertyInfo(name, value); dpi.description = description; diff --git a/src/main/java/org/duckdb/DuckDBFunctionException.java b/src/main/java/org/duckdb/DuckDBFunctionException.java new file mode 100644 index 000000000..2010fe433 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBFunctionException.java @@ -0,0 +1,13 @@ +package org.duckdb; + +public final class DuckDBFunctionException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public DuckDBFunctionException(String message) { + super(message); + } + + public DuckDBFunctionException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/org/duckdb/DuckDBFunctions.java b/src/main/java/org/duckdb/DuckDBFunctions.java new file mode 100644 index 000000000..0a9337301 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBFunctions.java @@ -0,0 +1,14 @@ +package org.duckdb; + +import java.sql.SQLException; + +public final class DuckDBFunctions { + public enum DuckDBFunctionKind { SCALAR } + + private DuckDBFunctions() { + } + + public static DuckDBScalarFunctionBuilder scalarFunction() throws SQLException { + return new DuckDBScalarFunctionBuilder(); + } +} diff --git a/src/main/java/org/duckdb/DuckDBHugeInt.java b/src/main/java/org/duckdb/DuckDBHugeInt.java index 5912e0fa2..eeca81f83 100644 --- a/src/main/java/org/duckdb/DuckDBHugeInt.java +++ b/src/main/java/org/duckdb/DuckDBHugeInt.java @@ -1,11 +1,13 @@ package org.duckdb; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.sql.SQLException; class DuckDBHugeInt { static final BigInteger HUGE_INT_MIN = BigInteger.ONE.shiftLeft(127).negate(); static final BigInteger HUGE_INT_MAX = BigInteger.ONE.shiftLeft(127).subtract(BigInteger.ONE); + static final BigInteger UHUGE_INT_MAX = BigInteger.ONE.shiftLeft(128).subtract(BigInteger.ONE); private final long lower; private final long upper; @@ -25,4 +27,24 @@ class DuckDBHugeInt { this.lower = bi.longValue(); this.upper = bi.shiftRight(64).longValue(); } + + static BigInteger toBigInteger(long lower, long upper) { + byte[] bytes = new byte[Long.BYTES * 2]; + ByteBuffer.wrap(bytes).putLong(upper).putLong(lower); + return new BigInteger(bytes); + } + + static BigInteger toUnsignedBigInteger(long lower, long upper) { + byte[] bytes = new byte[Long.BYTES * 2]; + ByteBuffer.wrap(bytes).putLong(upper).putLong(lower); + return new BigInteger(1, bytes); + } + + long lower() { + return lower; + } + + long upper() { + return upper; + } } diff --git a/src/main/java/org/duckdb/DuckDBLogicalType.java b/src/main/java/org/duckdb/DuckDBLogicalType.java new file mode 100644 index 000000000..2d4f53b10 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBLogicalType.java @@ -0,0 +1,97 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; +import static org.duckdb.DuckDBBindings.CAPIType.*; + +import java.nio.ByteBuffer; +import java.sql.SQLException; + +public final class DuckDBLogicalType implements AutoCloseable { + private ByteBuffer logicalTypeRef; + + private DuckDBLogicalType(ByteBuffer logicalTypeRef) throws SQLException { + if (logicalTypeRef == null) { + throw new SQLException("Failed to create logical type"); + } + this.logicalTypeRef = logicalTypeRef; + } + + public static DuckDBLogicalType of(DuckDBColumnType type) throws SQLException { + if (type == null) { + throw new SQLException("Logical type cannot be null"); + } + switch (type) { + case BOOLEAN: + return createPrimitive(DUCKDB_TYPE_BOOLEAN); + case TINYINT: + return createPrimitive(DUCKDB_TYPE_TINYINT); + case SMALLINT: + return createPrimitive(DUCKDB_TYPE_SMALLINT); + case INTEGER: + return createPrimitive(DUCKDB_TYPE_INTEGER); + case BIGINT: + return createPrimitive(DUCKDB_TYPE_BIGINT); + case HUGEINT: + return createPrimitive(DUCKDB_TYPE_HUGEINT); + case UTINYINT: + return createPrimitive(DUCKDB_TYPE_UTINYINT); + case USMALLINT: + return createPrimitive(DUCKDB_TYPE_USMALLINT); + case UINTEGER: + return createPrimitive(DUCKDB_TYPE_UINTEGER); + case UBIGINT: + return createPrimitive(DUCKDB_TYPE_UBIGINT); + case UHUGEINT: + return createPrimitive(DUCKDB_TYPE_UHUGEINT); + case FLOAT: + return createPrimitive(DUCKDB_TYPE_FLOAT); + case DOUBLE: + return createPrimitive(DUCKDB_TYPE_DOUBLE); + case VARCHAR: + return createPrimitive(DUCKDB_TYPE_VARCHAR); + case DATE: + return createPrimitive(DUCKDB_TYPE_DATE); + case TIMESTAMP_S: + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_S); + case TIMESTAMP_MS: + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_MS); + case TIMESTAMP: + return createPrimitive(DUCKDB_TYPE_TIMESTAMP); + case TIMESTAMP_NS: + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_NS); + case TIMESTAMP_WITH_TIME_ZONE: + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_TZ); + default: + throw new SQLException("Unsupported logical type for scalar UDF registration: " + type); + } + } + + public static DuckDBLogicalType decimal(int width, int scale) throws SQLException { + if (width < 1 || width > 38) { + throw new SQLException("DECIMAL width must be between 1 and 38, got: " + width); + } + if (scale < 0 || scale > width) { + throw new SQLException("DECIMAL scale must be between 0 and width, got: " + scale); + } + return new DuckDBLogicalType(duckdb_create_decimal_type(width, scale)); + } + + ByteBuffer logicalTypeRef() throws SQLException { + if (logicalTypeRef == null) { + throw new SQLException("Logical type is already closed"); + } + return logicalTypeRef; + } + + @Override + public void close() { + if (logicalTypeRef != null) { + duckdb_destroy_logical_type(logicalTypeRef); + logicalTypeRef = null; + } + } + + private static DuckDBLogicalType createPrimitive(DuckDBBindings.CAPIType type) throws SQLException { + return new DuckDBLogicalType(duckdb_create_logical_type(type.typeId)); + } +} diff --git a/src/main/java/org/duckdb/DuckDBReadableVector.java b/src/main/java/org/duckdb/DuckDBReadableVector.java new file mode 100644 index 000000000..52324ad70 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBReadableVector.java @@ -0,0 +1,86 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.util.stream.LongStream; + +/** + * Read-only scalar callback view over a DuckDB vector. + * + *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value errors. + * Invalid row indexes throw {@link IndexOutOfBoundsException}. + */ +public abstract class DuckDBReadableVector { + public abstract DuckDBColumnType getType(); + + public abstract long rowCount(); + + public abstract LongStream rowIndexStream(); + + public abstract boolean isNull(long row); + + public abstract boolean getBoolean(long row); + + public abstract boolean getBoolean(long row, boolean defaultVal); + + public abstract byte getByte(long row); + + public abstract byte getByte(long row, byte defaultVal); + + public abstract short getShort(long row); + + public abstract short getShort(long row, short defaultVal); + + public abstract short getUint8(long row); + + public abstract short getUint8(long row, short defaultVal); + + public abstract int getUint16(long row); + + public abstract int getUint16(long row, int defaultVal); + + public abstract int getInt(long row); + + public abstract int getInt(long row, int defaultVal); + + public abstract long getUint32(long row); + + public abstract long getUint32(long row, long defaultVal); + + public abstract long getLong(long row); + + public abstract long getLong(long row, long defaultVal); + + public abstract BigInteger getHugeInt(long row); + + public abstract BigInteger getUHugeInt(long row); + + public abstract BigInteger getUint64(long row); + + public abstract float getFloat(long row); + + public abstract float getFloat(long row, float defaultVal); + + public abstract double getDouble(long row); + + public abstract double getDouble(long row, double defaultVal); + + public abstract LocalDate getLocalDate(long row); + + public abstract Date getDate(long row); + + public abstract LocalDateTime getLocalDateTime(long row); + + public abstract Timestamp getTimestamp(long row); + + public abstract OffsetDateTime getOffsetDateTime(long row); + + public abstract BigDecimal getBigDecimal(long row); + + public abstract String getString(long row); +} diff --git a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java new file mode 100644 index 000000000..58b0a3f03 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java @@ -0,0 +1,426 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.stream.LongStream; + +final class DuckDBReadableVectorImpl extends DuckDBReadableVector { + private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private final ByteBuffer validity; + + DuckDBReadableVectorImpl(ByteBuffer vectorRef, long rowCount) { + if (vectorRef == null) { + throw new DuckDBFunctionException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Failed to resolve vector type info", exception); + } + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); + this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); + } + + @Override + public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + @Override + public long rowCount() { + return rowCount; + } + + @Override + public LongStream rowIndexStream() { + return LongStream.range(0, rowCount); + } + + @Override + public boolean isNull(long row) { + checkRowIndex(row); + if (validity == null) { + return false; + } + int entryPos = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); + long mask = validity.getLong(entryPos); + return (mask & (1L << (row % Long.SIZE))) == 0; + } + + @Override + public boolean getBoolean(long row) { + requireType(DuckDBColumnType.BOOLEAN); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BOOLEAN, row); + } + return data.get(checkedRowIndex(row)) != 0; + } + + @Override + public boolean getBoolean(long row, boolean defaultVal) { + requireType(DuckDBColumnType.BOOLEAN); + return isNull(row) ? defaultVal : data.get(checkedRowIndex(row)) != 0; + } + + @Override + public byte getByte(long row) { + requireType(DuckDBColumnType.TINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.TINYINT, row); + } + return data.get(checkedRowIndex(row)); + } + + @Override + public byte getByte(long row, byte defaultVal) { + requireType(DuckDBColumnType.TINYINT); + return isNull(row) ? defaultVal : data.get(checkedRowIndex(row)); + } + + @Override + public short getShort(long row) { + requireType(DuckDBColumnType.SMALLINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.SMALLINT, row); + } + return data.getShort(checkedByteOffset(row, Short.BYTES)); + } + + @Override + public short getShort(long row, short defaultVal) { + requireType(DuckDBColumnType.SMALLINT); + return isNull(row) ? defaultVal : data.getShort(checkedByteOffset(row, Short.BYTES)); + } + + @Override + public short getUint8(long row) { + requireType(DuckDBColumnType.UTINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UTINYINT, row); + } + return (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + @Override + public short getUint8(long row, short defaultVal) { + requireType(DuckDBColumnType.UTINYINT); + return isNull(row) ? defaultVal : (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + @Override + public int getUint16(long row) { + requireType(DuckDBColumnType.USMALLINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.USMALLINT, row); + } + return Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); + } + + @Override + public int getUint16(long row, int defaultVal) { + requireType(DuckDBColumnType.USMALLINT); + return isNull(row) ? defaultVal : Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); + } + + @Override + public int getInt(long row) { + requireType(DuckDBColumnType.INTEGER); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.INTEGER, row); + } + return data.getInt(checkedByteOffset(row, Integer.BYTES)); + } + + @Override + public int getInt(long row, int defaultVal) { + requireType(DuckDBColumnType.INTEGER); + return isNull(row) ? defaultVal : data.getInt(checkedByteOffset(row, Integer.BYTES)); + } + + @Override + public long getUint32(long row) { + requireType(DuckDBColumnType.UINTEGER); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UINTEGER, row); + } + return Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public long getUint32(long row, long defaultVal) { + requireType(DuckDBColumnType.UINTEGER); + return isNull(row) ? defaultVal : Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public long getLong(long row) { + requireType(DuckDBColumnType.BIGINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BIGINT, row); + } + return data.getLong(checkedByteOffset(row, Long.BYTES)); + } + + @Override + public long getLong(long row, long defaultVal) { + requireType(DuckDBColumnType.BIGINT); + return isNull(row) ? defaultVal : data.getLong(checkedByteOffset(row, Long.BYTES)); + } + + @Override + public BigInteger getHugeInt(long row) { + requireType(DuckDBColumnType.HUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toBigInteger(lower, upper); + } + + @Override + public BigInteger getUHugeInt(long row) { + requireType(DuckDBColumnType.UHUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toUnsignedBigInteger(lower, upper); + } + + @Override + public BigInteger getUint64(long row) { + requireType(DuckDBColumnType.UBIGINT); + if (isNull(row)) { + return null; + } + long value = data.getLong(checkedByteOffset(row, Long.BYTES)); + return unsignedLongToBigInteger(value); + } + + @Override + public float getFloat(long row) { + requireType(DuckDBColumnType.FLOAT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.FLOAT, row); + } + return data.getFloat(checkedByteOffset(row, Float.BYTES)); + } + + @Override + public float getFloat(long row, float defaultVal) { + requireType(DuckDBColumnType.FLOAT); + return isNull(row) ? defaultVal : data.getFloat(checkedByteOffset(row, Float.BYTES)); + } + + @Override + public double getDouble(long row) { + requireType(DuckDBColumnType.DOUBLE); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.DOUBLE, row); + } + return data.getDouble(checkedByteOffset(row, Double.BYTES)); + } + + @Override + public double getDouble(long row, double defaultVal) { + requireType(DuckDBColumnType.DOUBLE); + return isNull(row) ? defaultVal : data.getDouble(checkedByteOffset(row, Double.BYTES)); + } + + @Override + public LocalDate getLocalDate(long row) { + requireType(DuckDBColumnType.DATE); + if (isNull(row)) { + return null; + } + return LocalDate.ofEpochDay(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public Date getDate(long row) { + LocalDate value = getLocalDate(row); + return value == null ? null : Date.valueOf(value); + } + + @Override + public LocalDateTime getLocalDateTime(long row) { + requireTimestampType(); + if (isNull(row)) { + return null; + } + long epochValue = data.getLong(checkedByteOffset(row, Long.BYTES)); + try { + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); + case DUCKDB_TYPE_TIMESTAMP: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); + default: + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Failed to decode timestamp at row " + row, exception); + } + } + + @Override + public Timestamp getTimestamp(long row) { + LocalDateTime value = getLocalDateTime(row); + return value == null ? null : Timestamp.valueOf(value); + } + + @Override + public OffsetDateTime getOffsetDateTime(long row) { + requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + if (isNull(row)) { + return null; + } + long micros = data.getLong(checkedByteOffset(row, Long.BYTES)); + Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); + return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); + } + + @Override + public BigDecimal getBigDecimal(long row) { + requireType(DuckDBColumnType.DECIMAL); + if (isNull(row)) { + return null; + } + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_INTEGER: + return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_BIGINT: + return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_HUGEINT: { + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return new BigDecimal(upper) + .multiply(ULONG_MULTIPLIER) + .add(new BigDecimal(Long.toUnsignedString(lower))) + .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); + } + default: + throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + } + + @Override + public String getString(long row) { + requireType(DuckDBColumnType.VARCHAR); + if (isNull(row)) { + return null; + } + byte[] bytes = duckdb_vector_get_string(data, row); + if (bytes == null) { + return null; + } + return new String(bytes, UTF_8); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void requireType(DuckDBColumnType expected) { + if (typeInfo.columnType != expected) { + throw new DuckDBFunctionException("Expected vector type " + expected + ", found " + typeInfo.columnType); + } + } + + private void requireTimestampType() { + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return; + default: + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private void checkRowIndex(long row) { + if (row < 0 || row >= rowCount) { + throw new IndexOutOfBoundsException("Row index out of bounds: " + row); + } + } + + private int checkedRowIndex(long row) { + checkRowIndex(row); + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + checkRowIndex(row); + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } + + private static Instant instantFromEpoch(long value, ChronoUnit unit) { + switch (unit) { + case SECONDS: + return Instant.ofEpochSecond(value); + case MILLIS: + return Instant.ofEpochMilli(value); + case MICROS: { + long epochSecond = Math.floorDiv(value, 1_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000L) * 1000L; + return Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + case NANOS: { + long epochSecond = Math.floorDiv(value, 1_000_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000_000L); + return Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + default: + throw new DuckDBFunctionException("Unsupported unit type: " + unit); + } + } + + private static BigInteger unsignedLongToBigInteger(long value) { + if (value >= 0) { + return BigInteger.valueOf(value); + } + return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); + } + + private static DuckDBFunctionException primitiveNullValue(DuckDBColumnType type, long row) { + return new DuckDBFunctionException("Primitive value for " + type + " at row " + row + " is NULL"); + } +} diff --git a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java new file mode 100644 index 000000000..f11f6c968 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java @@ -0,0 +1,98 @@ +package org.duckdb; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public final class DuckDBRegisteredFunction { + private final String name; + private final DuckDBFunctions.DuckDBFunctionKind functionKind; + private final List parameterTypes; + private final List parameterColumnTypes; + private final DuckDBLogicalType returnType; + private final DuckDBColumnType returnColumnType; + private final DuckDBScalarFunction function; + private final DuckDBLogicalType varArgType; + private final boolean volatileFlag; + private final boolean specialHandlingFlag; + private final boolean propagateNullsFlag; + + private DuckDBRegisteredFunction(String name, DuckDBFunctions.DuckDBFunctionKind functionKind, + List parameterTypes, + List parameterColumnTypes, DuckDBLogicalType returnType, + DuckDBColumnType returnColumnType, DuckDBScalarFunction function, + DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, + boolean propagateNullsFlag) { + this.name = name; + this.functionKind = functionKind; + this.parameterTypes = parameterTypes; + this.parameterColumnTypes = parameterColumnTypes; + this.returnType = returnType; + this.returnColumnType = returnColumnType; + this.function = function; + this.varArgType = varArgType; + this.volatileFlag = volatileFlag; + this.specialHandlingFlag = specialHandlingFlag; + this.propagateNullsFlag = propagateNullsFlag; + } + + public String name() { + return name; + } + + public DuckDBFunctions.DuckDBFunctionKind functionKind() { + return functionKind; + } + + public List parameterTypes() { + return parameterTypes; + } + + public List parameterColumnTypes() { + return parameterColumnTypes; + } + + public DuckDBLogicalType returnType() { + return returnType; + } + + public DuckDBColumnType returnColumnType() { + return returnColumnType; + } + + public DuckDBScalarFunction function() { + return function; + } + + public DuckDBLogicalType varArgType() { + return varArgType; + } + + public boolean isVolatile() { + return volatileFlag; + } + + public boolean hasSpecialHandling() { + return specialHandlingFlag; + } + + public boolean propagateNulls() { + return propagateNullsFlag; + } + + public boolean isScalar() { + return functionKind == DuckDBFunctions.DuckDBFunctionKind.SCALAR; + } + + static DuckDBRegisteredFunction of(String name, List parameterTypes, + List parameterColumnTypes, DuckDBLogicalType returnType, + DuckDBColumnType returnColumnType, DuckDBScalarFunction function, + DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, + boolean propagateNullsFlag) { + return new DuckDBRegisteredFunction(name, DuckDBFunctions.DuckDBFunctionKind.SCALAR, + Collections.unmodifiableList(new ArrayList<>(parameterTypes)), + Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), + returnType, returnColumnType, function, varArgType, volatileFlag, + specialHandlingFlag, propagateNullsFlag); + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarContext.java b/src/main/java/org/duckdb/DuckDBScalarContext.java new file mode 100644 index 000000000..ad0c927ae --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarContext.java @@ -0,0 +1,90 @@ +package org.duckdb; + +import java.util.stream.LongStream; +import java.util.stream.Stream; + +/** + * Per-invocation scalar callback context. + * + *

Runtime type/value failures are surfaced as {@link DuckDBFunctionException}. Invalid row or + * column indexes throw {@link IndexOutOfBoundsException}. + */ +public final class DuckDBScalarContext { + private final DuckDBDataChunkReader input; + private final DuckDBWritableVector output; + private boolean propagateNulls; + + DuckDBScalarContext(DuckDBDataChunkReader input, DuckDBWritableVector output, boolean propagateNulls) { + if (input == null) { + throw new IllegalArgumentException("Input chunk cannot be null"); + } + if (output == null) { + throw new IllegalArgumentException("Output vector cannot be null"); + } + this.input = input; + this.output = output; + this.propagateNulls = propagateNulls; + } + + public long rowCount() { + return input.rowCount(); + } + + public long columnCount() { + return input.columnCount(); + } + + public DuckDBReadableVector input(long columnIndex) { + return input.vector(columnIndex); + } + + public DuckDBWritableVector output() { + return output; + } + + public boolean nullsPropagated() { + return propagateNulls; + } + + public DuckDBScalarContext propagateNulls(boolean propagateNulls) { + this.propagateNulls = propagateNulls; + return this; + } + + DuckDBDataChunkReader inputChunk() { + return input; + } + + public Stream stream() { + LongStream rows = LongStream.range(0, rowCount()); + if (propagateNulls) { + rows = rows.filter(this::rowHasNoNullInputs); + } + return rows.mapToObj(this::row); + } + + public DuckDBScalarRow row(long rowIndex) { + checkRowIndex(rowIndex); + return new DuckDBScalarRow(this, rowIndex); + } + + DuckDBReadableVector inputUnchecked(long columnIndex) { + return input(columnIndex); + } + + private void checkRowIndex(long rowIndex) { + if (rowIndex < 0 || rowIndex >= rowCount()) { + throw new IndexOutOfBoundsException("Row index out of bounds: " + rowIndex); + } + } + + private boolean rowHasNoNullInputs(long rowIndex) { + for (long columnIndex = 0; columnIndex < columnCount(); columnIndex++) { + if (inputUnchecked(columnIndex).isNull(rowIndex)) { + output.setNull(rowIndex); + return false; + } + } + return true; + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunction.java b/src/main/java/org/duckdb/DuckDBScalarFunction.java new file mode 100644 index 000000000..beef14162 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunction.java @@ -0,0 +1,15 @@ +package org.duckdb; + +@FunctionalInterface +public interface DuckDBScalarFunction { + /** + * Processes a full input chunk and writes one output value per row directly into the DuckDB output vector. + * + *

The context and all wrappers returned from it are valid only for the duration of the callback and must not + * be retained. + * + * @param ctx scalar function execution context for the current chunk + * @throws Exception when function execution fails + */ + void apply(DuckDBScalarContext ctx) throws Exception; +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java new file mode 100644 index 000000000..f67d20a72 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -0,0 +1,550 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; +import java.util.function.IntBinaryOperator; +import java.util.function.IntUnaryOperator; +import java.util.function.LongBinaryOperator; +import java.util.function.LongUnaryOperator; +import java.util.function.Supplier; + +final class DuckDBScalarFunctionAdapter { + private static final Map CODECS_BY_DUCKDB_TYPE = new LinkedHashMap<>(); + private static final Map, DuckDBColumnType> DUCKDB_TYPE_BY_JAVA_CLASS = new LinkedHashMap<>(); + private static final Map, Class> BOXED_CLASSES = new LinkedHashMap<>(); + private static final TypeCodec DATE_SQL_CODEC = + new TypeCodec(java.sql.Date.class, DuckDBReadableVector::getDate, + (vector, row, value) -> vector.setDate(row, (java.sql.Date) value)); + private static final TypeCodec TIMESTAMP_SQL_CODEC = + new TypeCodec(java.sql.Timestamp.class, DuckDBReadableVector::getTimestamp, + (vector, row, value) -> vector.setTimestamp(row, (java.sql.Timestamp) value)); + private static final TypeCodec TIMESTAMP_UTIL_DATE_CODEC = new TypeCodec( + Date.class, + (vector, row) + -> Date.from(vector.getLocalDateTime(row).toInstant(ZoneOffset.UTC)), + (vector, row, + value) -> vector.setTimestamp(row, LocalDateTime.ofInstant(((Date) value).toInstant(), ZoneOffset.UTC))); + + static { + register(DuckDBColumnType.BOOLEAN, Boolean.class, DuckDBReadableVector::getBoolean, + DuckDBWritableVector::setBoolean); + register(DuckDBColumnType.TINYINT, Byte.class, DuckDBReadableVector::getByte, DuckDBWritableVector::setByte); + register(DuckDBColumnType.UTINYINT, Short.class, DuckDBReadableVector::getUint8, + (out, row, value) -> out.setUint8(row, value)); + register(DuckDBColumnType.SMALLINT, Short.class, DuckDBReadableVector::getShort, + DuckDBWritableVector::setShort); + register(DuckDBColumnType.USMALLINT, Integer.class, DuckDBReadableVector::getUint16, + (out, row, value) -> out.setUint16(row, value)); + register(DuckDBColumnType.INTEGER, Integer.class, DuckDBReadableVector::getInt, DuckDBWritableVector::setInt); + register(DuckDBColumnType.UINTEGER, Long.class, DuckDBReadableVector::getUint32, + (out, row, value) -> out.setUint32(row, value)); + register(DuckDBColumnType.BIGINT, Long.class, DuckDBReadableVector::getLong, DuckDBWritableVector::setLong); + register(DuckDBColumnType.HUGEINT, BigInteger.class, DuckDBReadableVector::getHugeInt, + DuckDBWritableVector::setHugeInt); + register(DuckDBColumnType.UHUGEINT, BigInteger.class, DuckDBReadableVector::getUHugeInt, + DuckDBWritableVector::setUHugeInt); + register(DuckDBColumnType.UBIGINT, BigInteger.class, DuckDBReadableVector::getUint64, + DuckDBWritableVector::setUint64); + register(DuckDBColumnType.FLOAT, Float.class, DuckDBReadableVector::getFloat, DuckDBWritableVector::setFloat); + register(DuckDBColumnType.DOUBLE, Double.class, DuckDBReadableVector::getDouble, + DuckDBWritableVector::setDouble); + register(DuckDBColumnType.DECIMAL, BigDecimal.class, DuckDBReadableVector::getBigDecimal, + DuckDBWritableVector::setBigDecimal); + register(DuckDBColumnType.VARCHAR, String.class, DuckDBReadableVector::getString, + DuckDBWritableVector::setString); + register(DuckDBColumnType.DATE, LocalDate.class, DuckDBReadableVector::getLocalDate, + DuckDBWritableVector::setDate); + register(DuckDBColumnType.TIMESTAMP_S, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_MS, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_NS, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, OffsetDateTime.class, + DuckDBReadableVector::getOffsetDateTime, DuckDBWritableVector::setOffsetDateTime); + + DUCKDB_TYPE_BY_JAVA_CLASS.put(boolean.class, DuckDBColumnType.BOOLEAN); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Boolean.class, DuckDBColumnType.BOOLEAN); + DUCKDB_TYPE_BY_JAVA_CLASS.put(byte.class, DuckDBColumnType.TINYINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Byte.class, DuckDBColumnType.TINYINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(short.class, DuckDBColumnType.SMALLINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Short.class, DuckDBColumnType.SMALLINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(int.class, DuckDBColumnType.INTEGER); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Integer.class, DuckDBColumnType.INTEGER); + DUCKDB_TYPE_BY_JAVA_CLASS.put(long.class, DuckDBColumnType.BIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Long.class, DuckDBColumnType.BIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(float.class, DuckDBColumnType.FLOAT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Float.class, DuckDBColumnType.FLOAT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(double.class, DuckDBColumnType.DOUBLE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Double.class, DuckDBColumnType.DOUBLE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(String.class, DuckDBColumnType.VARCHAR); + DUCKDB_TYPE_BY_JAVA_CLASS.put(BigDecimal.class, DuckDBColumnType.DECIMAL); + DUCKDB_TYPE_BY_JAVA_CLASS.put(BigInteger.class, DuckDBColumnType.HUGEINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDate.class, DuckDBColumnType.DATE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(java.sql.Date.class, DuckDBColumnType.DATE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDateTime.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(java.sql.Timestamp.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Date.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(OffsetDateTime.class, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + + BOXED_CLASSES.put(boolean.class, Boolean.class); + BOXED_CLASSES.put(byte.class, Byte.class); + BOXED_CLASSES.put(short.class, Short.class); + BOXED_CLASSES.put(int.class, Integer.class); + BOXED_CLASSES.put(long.class, Long.class); + BOXED_CLASSES.put(float.class, Float.class); + BOXED_CLASSES.put(double.class, Double.class); + } + + static DuckDBScalarFunction unary(Function function, DuckDBColumnType parameterType, + Class parameterJavaType, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + @SuppressWarnings("unchecked") Function typedFunction = (Function) function; + TypeCodec inCodec = codecFor(parameterType, parameterJavaType); + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + boolean propagateNulls = ctx.nullsPropagated(); + for (long row = 0; row < rowCount; row++) { + try { + if (propagateNulls && in.isNull(row)) { + out.setNull(row); + continue; + } + Object argument = in.isNull(row) ? null : inCodec.read(in, row); + Object result = typedFunction.apply(argument); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute unary scalar function at row " + row, + exception); + } + } + }; + } + + static DuckDBScalarFunction binary(BiFunction function, DuckDBColumnType leftType, Class leftJavaType, + DuckDBColumnType rightType, Class rightJavaType, DuckDBColumnType returnType, + Class returnJavaType) throws SQLException { + @SuppressWarnings("unchecked") + BiFunction typedFunction = (BiFunction) function; + TypeCodec leftCodec = codecFor(leftType, leftJavaType); + TypeCodec rightCodec = codecFor(rightType, rightJavaType); + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + boolean propagateNulls = ctx.nullsPropagated(); + for (long row = 0; row < rowCount; row++) { + try { + boolean leftIsNull = left.isNull(row); + boolean rightIsNull = right.isNull(row); + if (propagateNulls && (leftIsNull || rightIsNull)) { + out.setNull(row); + continue; + } + Object leftValue = leftIsNull ? null : leftCodec.read(left, row); + Object rightValue = rightIsNull ? null : rightCodec.read(right, row); + Object result = typedFunction.apply(leftValue, rightValue); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute binary scalar function at row " + row, + exception); + } + } + }; + } + + static DuckDBScalarFunction intUnary(IntUnaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(in.getInt(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction intBinary(IntBinaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); + } + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, function.applyAsDouble(in.getDouble(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); + } + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction longUnary(LongUnaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(in.getLong(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction longBinary(LongBinaryOperator function) { + return ctx -> { + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); + } + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); + } + } + }; + } + + static DuckDBScalarFunction nullary(Supplier function, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + @SuppressWarnings("unchecked") Supplier typedFunction = (Supplier) function; + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + try { + Object result = typedFunction.get(); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute supplier scalar function at row " + row, + exception); + } + } + }; + } + + static DuckDBScalarFunction variadic(Function function, DuckDBColumnType[] fixedTypes, + Class[] fixedJavaTypes, DuckDBColumnType varArgType, + Class varArgJavaType, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + TypeCodec outCodec = codecFor(returnType, returnJavaType); + TypeCodec varArgCodec = codecFor(varArgType, varArgJavaType); + TypeCodec[] fixedCodecs = new TypeCodec[fixedTypes.length]; + for (int i = 0; i < fixedTypes.length; i++) { + fixedCodecs[i] = codecFor(fixedTypes[i], fixedJavaTypes[i]); + } + return ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + int vectorCount = Math.toIntExact(ctx.columnCount()); + boolean propagateNulls = ctx.nullsPropagated(); + DuckDBReadableVector[] vectors = new DuckDBReadableVector[vectorCount]; + TypeCodec[] codecs = new TypeCodec[vectorCount]; + for (int column = 0; column < vectorCount; column++) { + vectors[column] = ctx.input(column); + codecs[column] = column < fixedCodecs.length ? fixedCodecs[column] : varArgCodec; + } + Object[] args = new Object[vectorCount]; + for (long row = 0; row < rowCount; row++) { + try { + boolean skipRow = false; + for (int column = 0; column < vectorCount; column++) { + DuckDBReadableVector vector = vectors[column]; + if (vector.isNull(row)) { + if (propagateNulls) { + out.setNull(row); + skipRow = true; + break; + } + args[column] = null; + } else { + args[column] = codecs[column].read(vector, row); + } + } + if (skipRow) { + continue; + } + Object result = function.apply(args); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute variadic scalar function at row " + row, + exception); + } + } + }; + } + + static DuckDBColumnType mapJavaClassToDuckDBType(Class javaType) throws SQLException { + if (javaType == null) { + throw new SQLException("Java type cannot be null"); + } + Class normalizedClass = normalizeJavaClass(javaType); + DuckDBColumnType mappedType = DUCKDB_TYPE_BY_JAVA_CLASS.get(normalizedClass); + if (mappedType != null) { + return mappedType; + } + throw new SQLException("Unsupported Java type for scalar function mapping: " + javaType.getName()); + } + + static DuckDBColumnType mapLogicalTypeToDuckDBType(DuckDBLogicalType logicalType) throws SQLException { + if (logicalType == null) { + throw new SQLException("Logical type cannot be null"); + } + DuckDBBindings.CAPIType type = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_get_type_id(logicalType.logicalTypeRef())); + switch (type) { + case DUCKDB_TYPE_BOOLEAN: + return DuckDBColumnType.BOOLEAN; + case DUCKDB_TYPE_TINYINT: + return DuckDBColumnType.TINYINT; + case DUCKDB_TYPE_UTINYINT: + return DuckDBColumnType.UTINYINT; + case DUCKDB_TYPE_SMALLINT: + return DuckDBColumnType.SMALLINT; + case DUCKDB_TYPE_USMALLINT: + return DuckDBColumnType.USMALLINT; + case DUCKDB_TYPE_INTEGER: + return DuckDBColumnType.INTEGER; + case DUCKDB_TYPE_UINTEGER: + return DuckDBColumnType.UINTEGER; + case DUCKDB_TYPE_BIGINT: + return DuckDBColumnType.BIGINT; + case DUCKDB_TYPE_HUGEINT: + return DuckDBColumnType.HUGEINT; + case DUCKDB_TYPE_UBIGINT: + return DuckDBColumnType.UBIGINT; + case DUCKDB_TYPE_UHUGEINT: + return DuckDBColumnType.UHUGEINT; + case DUCKDB_TYPE_FLOAT: + return DuckDBColumnType.FLOAT; + case DUCKDB_TYPE_DOUBLE: + return DuckDBColumnType.DOUBLE; + case DUCKDB_TYPE_DECIMAL: + return DuckDBColumnType.DECIMAL; + case DUCKDB_TYPE_VARCHAR: + return DuckDBColumnType.VARCHAR; + case DUCKDB_TYPE_DATE: + return DuckDBColumnType.DATE; + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBColumnType.TIMESTAMP_S; + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBColumnType.TIMESTAMP_MS; + case DUCKDB_TYPE_TIMESTAMP: + return DuckDBColumnType.TIMESTAMP; + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBColumnType.TIMESTAMP_NS; + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE; + default: + throw new SQLException("Unsupported logical type for Function/BiFunction mapping: " + type); + } + } + + private static TypeCodec codecFor(DuckDBColumnType type) throws SQLException { + TypeCodec codec = CODECS_BY_DUCKDB_TYPE.get(type); + if (codec != null) { + return codec; + } + throw new SQLException("Unsupported DuckDB type for Function/BiFunction mapping: " + type); + } + + private static TypeCodec codecFor(DuckDBColumnType type, Class declaredJavaType) throws SQLException { + if (declaredJavaType == null) { + return codecFor(type); + } + Class normalizedClass = normalizeJavaClass(declaredJavaType); + switch (type) { + case DATE: + if (normalizedClass == LocalDate.class) { + return codecFor(type); + } + if (normalizedClass == java.sql.Date.class) { + return DATE_SQL_CODEC; + } + break; + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + if (normalizedClass == LocalDateTime.class) { + return codecFor(type); + } + if (normalizedClass == java.sql.Timestamp.class) { + return TIMESTAMP_SQL_CODEC; + } + if (normalizedClass == Date.class) { + return TIMESTAMP_UTIL_DATE_CODEC; + } + break; + default: { + TypeCodec codec = codecFor(type); + if (codec.matches(normalizedClass)) { + return codec; + } + break; + } + } + throw new SQLException("Unsupported Java type " + normalizedClass.getName() + " for DuckDB type " + type + + " in functional scalar function mapping"); + } + + private static void register(DuckDBColumnType type, Class javaType, Reader reader, Writer writer) { + CODECS_BY_DUCKDB_TYPE.put(type, new TypeCodec(javaType, reader, writer)); + } + + private static Class normalizeJavaClass(Class javaType) { + Class boxedClass = BOXED_CLASSES.get(javaType); + return boxedClass != null ? boxedClass : javaType; + } + + private DuckDBScalarFunctionAdapter() { + } + + @FunctionalInterface + private interface Reader { + T read(DuckDBReadableVector vector, long row); + } + + @FunctionalInterface + private interface Writer { + void write(DuckDBWritableVector vector, long row, T value); + } + + private static final class TypeCodec { + private final Class javaType; + private final Reader reader; + private final Writer writer; + + private TypeCodec(Class javaType, Reader reader, Writer writer) { + this.javaType = javaType; + this.reader = reader; + this.writer = writer; + } + + boolean matches(Class declaredJavaType) { + return javaType == declaredJavaType; + } + + Object read(DuckDBReadableVector vector, long row) { + return reader.read(vector, row); + } + + void write(DuckDBWritableVector vector, long row, Object value) { + if (value == null) { + vector.setNull(row); + return; + } + if (!javaType.isInstance(value)) { + throw new ClassCastException("Expected value of type " + javaType.getName() + ", got " + + value.getClass().getName()); + } + @SuppressWarnings("unchecked") Writer typedWriter = (Writer) writer; + typedWriter.write(vector, row, value); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java new file mode 100644 index 000000000..6298d0713 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java @@ -0,0 +1,471 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.function.BiFunction; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; +import java.util.function.IntBinaryOperator; +import java.util.function.IntUnaryOperator; +import java.util.function.LongBinaryOperator; +import java.util.function.LongUnaryOperator; +import java.util.function.Supplier; + +public final class DuckDBScalarFunctionBuilder implements AutoCloseable { + private ByteBuffer scalarFunctionRef; + private String functionName; + private DuckDBLogicalType returnType; + private DuckDBColumnType returnColumnType; + private Class returnJavaType; + private DuckDBScalarFunction callback; + private DuckDBLogicalType varArgType; + private final List parameterTypes = new ArrayList<>(); + private final List parameterColumnTypes = new ArrayList<>(); + private final List> parameterJavaTypes = new ArrayList<>(); + private boolean volatileFlag; + private boolean specialHandlingFlag; + private boolean propagateNullsFlag; + private boolean finalized; + + DuckDBScalarFunctionBuilder() throws SQLException { + this.scalarFunctionRef = duckdb_create_scalar_function(); + if (scalarFunctionRef == null) { + throw new SQLException("Failed to create scalar function"); + } + } + + public DuckDBScalarFunctionBuilder withName(String name) throws SQLException { + ensureNotFinalized(); + if (name == null || name.trim().isEmpty()) { + throw new SQLException("Function name cannot be null or empty"); + } + this.functionName = name; + duckdb_scalar_function_set_name(scalarFunctionRef, name.getBytes(UTF_8)); + return this; + } + + public DuckDBScalarFunctionBuilder withReturnType(DuckDBLogicalType returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + this.returnType = returnType; + this.returnColumnType = null; + this.returnJavaType = null; + duckdb_scalar_function_set_return_type(scalarFunctionRef, returnType.logicalTypeRef()); + return this; + } + + public DuckDBScalarFunctionBuilder withParameter(DuckDBLogicalType parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + parameterTypes.add(parameterType); + parameterColumnTypes.add(null); + parameterJavaTypes.add(null); + duckdb_scalar_function_add_parameter(scalarFunctionRef, parameterType.logicalTypeRef()); + return this; + } + + public DuckDBScalarFunctionBuilder withReturnType(Class returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + DuckDBColumnType mappedType = DuckDBScalarFunctionAdapter.mapJavaClassToDuckDBType(returnType); + return setMappedReturnType(mappedType, returnType); + } + + public DuckDBScalarFunctionBuilder withParameter(Class parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + DuckDBColumnType mappedType = DuckDBScalarFunctionAdapter.mapJavaClassToDuckDBType(parameterType); + return addMappedParameterType(mappedType, parameterType); + } + + public DuckDBScalarFunctionBuilder withParameters(Class... parameterTypes) throws SQLException { + ensureNotFinalized(); + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + for (Class parameterType : parameterTypes) { + withParameter(parameterType); + } + return this; + } + + public DuckDBScalarFunctionBuilder withReturnType(DuckDBColumnType returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + return setMappedReturnType(returnType, null); + } + + public DuckDBScalarFunctionBuilder withParameter(DuckDBColumnType parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + return addMappedParameterType(parameterType, null); + } + + public DuckDBScalarFunctionBuilder withVectorizedFunction(DuckDBScalarFunction function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + return setCallback(function, false); + } + + public DuckDBScalarFunctionBuilder withIntFunction(IntUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withIntFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); + return setCallback(DuckDBScalarFunctionAdapter.intUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withIntFunction(IntBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withIntFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); + return setCallback(DuckDBScalarFunctionAdapter.intBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withDoubleFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); + return setCallback(DuckDBScalarFunctionAdapter.doubleUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withDoubleFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); + return setCallback(DuckDBScalarFunctionAdapter.doubleBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withLongFunction(LongUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withLongFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); + return setCallback(DuckDBScalarFunctionAdapter.longUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withLongFunction(LongBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + enablePrimitiveNullPropagation(); + ensurePrimitiveCallbackCompatible("withLongFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); + return setCallback(DuckDBScalarFunctionAdapter.longBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withFunction(Function function) + throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (varArgType != null) { + throw new SQLException("Function callback does not support varargs; use withVarArgsFunction instead"); + } + if (parameterTypes.size() != 1) { + throw new SQLException("Function callback requires exactly 1 declared parameter"); + } + DuckDBColumnType parameterType = effectiveParameterType(0); + Class parameterJavaType = effectiveParameterJavaType(0); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.unary(function, parameterType, parameterJavaType, + resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withFunction(BiFunction function) + throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (varArgType != null) { + throw new SQLException("BiFunction callback does not support varargs; use withVarArgsFunction instead"); + } + if (parameterTypes.size() != 2) { + throw new SQLException("BiFunction callback requires exactly 2 declared parameters"); + } + DuckDBColumnType leftType = effectiveParameterType(0); + Class leftJavaType = effectiveParameterJavaType(0); + DuckDBColumnType rightType = effectiveParameterType(1); + Class rightJavaType = effectiveParameterJavaType(1); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.binary( + function, leftType, leftJavaType, rightType, rightJavaType, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withFunction(Supplier function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (!parameterTypes.isEmpty()) { + throw new SQLException("Supplier callback requires zero declared parameters"); + } + if (varArgType != null) { + throw new SQLException("Supplier callback does not support varargs"); + } + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction( + DuckDBScalarFunctionAdapter.nullary(function, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withVarArgsFunction(Function function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (varArgType == null) { + throw new SQLException("Varargs functional callback requires withVarArgs(...) declaration"); + } + DuckDBColumnType[] fixedTypes = effectiveFixedParameterTypes(); + Class[] fixedJavaTypes = effectiveFixedParameterJavaTypes(); + DuckDBColumnType varArgColumnType = DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(varArgType); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.variadic( + function, fixedTypes, fixedJavaTypes, varArgColumnType, null, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withVarArgs(DuckDBLogicalType varArgType) throws SQLException { + ensureNotFinalized(); + if (varArgType == null) { + throw new SQLException("Varargs type cannot be null"); + } + this.varArgType = varArgType; + duckdb_scalar_function_set_varargs(scalarFunctionRef, varArgType.logicalTypeRef()); + return this; + } + + public DuckDBScalarFunctionBuilder withVolatile() throws SQLException { + ensureNotFinalized(); + this.volatileFlag = true; + duckdb_scalar_function_set_volatile(scalarFunctionRef); + return this; + } + + public DuckDBScalarFunctionBuilder withSpecialHandling() throws SQLException { + ensureNotFinalized(); + this.specialHandlingFlag = true; + duckdb_scalar_function_set_special_handling(scalarFunctionRef); + return this; + } + + public DuckDBRegisteredFunction register(Connection connection) throws SQLException { + ensureNotFinalized(); + if (connection == null) { + throw new SQLException("Connection cannot be null"); + } + if (functionName == null) { + throw new SQLException("Function name must be defined"); + } + if (returnType == null && returnColumnType == null) { + throw new SQLException("Return type must be defined"); + } + if (callback == null) { + throw new SQLException("Scalar function callback must be defined"); + } + DuckDBConnection duckConnection = unwrapConnection(connection); + Lock connectionLock = duckConnection.connRefLock; + connectionLock.lock(); + try { + duckConnection.checkOpen(); + int status = duckdb_register_scalar_function(duckConnection.connRef, scalarFunctionRef); + if (status != 0) { + throw new SQLException("Failed to register scalar function '" + functionName + "'"); + } + DuckDBRegisteredFunction registeredFunction = DuckDBRegisteredFunction.of( + functionName, parameterTypes, parameterColumnTypes, returnType, returnColumnType, callback, varArgType, + volatileFlag, specialHandlingFlag, propagateNullsFlag); + DuckDBDriver.registerFunction(registeredFunction); + return registeredFunction; + } finally { + connectionLock.unlock(); + close(); + } + } + + @Override + public void close() { + if (scalarFunctionRef != null) { + duckdb_destroy_scalar_function(scalarFunctionRef); + scalarFunctionRef = null; + } + finalized = true; + } + + private void ensureNotFinalized() throws SQLException { + if (finalized || scalarFunctionRef == null) { + throw new SQLException("Scalar function builder is already finalized"); + } + } + + private DuckDBColumnType effectiveParameterType(int index) throws SQLException { + DuckDBColumnType parameterColumnType = parameterColumnTypes.get(index); + if (parameterColumnType != null) { + return parameterColumnType; + } + DuckDBLogicalType parameterLogicalType = parameterTypes.get(index); + return DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(parameterLogicalType); + } + + private DuckDBColumnType effectiveReturnType() throws SQLException { + if (returnColumnType != null) { + return returnColumnType; + } + if (returnType != null) { + return DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(returnType); + } + throw new SQLException("Return type must be defined before functional callback"); + } + + private Class effectiveParameterJavaType(int index) { + return parameterJavaTypes.get(index); + } + + private Class effectiveReturnJavaType() { + return returnJavaType; + } + + private DuckDBColumnType[] effectiveFixedParameterTypes() throws SQLException { + DuckDBColumnType[] fixedTypes = new DuckDBColumnType[parameterTypes.size()]; + for (int i = 0; i < fixedTypes.length; i++) { + fixedTypes[i] = effectiveParameterType(i); + } + return fixedTypes; + } + + private Class[] effectiveFixedParameterJavaTypes() { + return parameterJavaTypes.toArray(new Class[ 0 ]); + } + + private DuckDBScalarFunctionBuilder setMappedReturnType(DuckDBColumnType mappedType, Class javaType) + throws SQLException { + this.returnType = null; + this.returnColumnType = mappedType; + this.returnJavaType = javaType; + try (DuckDBLogicalType logicalType = DuckDBLogicalType.of(mappedType)) { + duckdb_scalar_function_set_return_type(scalarFunctionRef, logicalType.logicalTypeRef()); + } + return this; + } + + private DuckDBScalarFunctionBuilder addMappedParameterType(DuckDBColumnType mappedType, Class javaType) + throws SQLException { + parameterTypes.add(null); + parameterColumnTypes.add(mappedType); + parameterJavaTypes.add(javaType); + try (DuckDBLogicalType logicalType = DuckDBLogicalType.of(mappedType)) { + duckdb_scalar_function_add_parameter(scalarFunctionRef, logicalType.logicalTypeRef()); + } + return this; + } + + private DuckDBScalarFunctionBuilder setCallback(DuckDBScalarFunction function, boolean requiresNullPropagation) + throws SQLException { + this.callback = function; + this.propagateNullsFlag = requiresNullPropagation; + duckdb_scalar_function_set_function(scalarFunctionRef, + new DuckDBScalarFunctionWrapper(function, propagateNullsFlag)); + return this; + } + + private void ensurePrimitiveCallbackCompatible(String callbackMethodName) throws SQLException { + if (varArgType != null) { + throw new SQLException(callbackMethodName + " does not support varargs; use withVarArgsFunction instead"); + } + } + + private void enablePrimitiveNullPropagation() { + propagateNullsFlag = true; + } + + private void ensureUnaryPrimitiveSignature(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + if (parameterTypes.size() != 1) { + throw new SQLException(callbackMethodName + " requires exactly 1 declared parameter"); + } + ensurePrimitiveParameterType(0, expectedType, callbackMethodName); + ensurePrimitiveReturnType(expectedType, callbackMethodName); + } + + private void ensureBinaryPrimitiveSignature(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + if (parameterTypes.size() != 2) { + throw new SQLException(callbackMethodName + " requires exactly 2 declared parameters"); + } + ensurePrimitiveParameterType(0, expectedType, callbackMethodName); + ensurePrimitiveParameterType(1, expectedType, callbackMethodName); + ensurePrimitiveReturnType(expectedType, callbackMethodName); + } + + private void ensurePrimitiveParameterType(int index, DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + DuckDBColumnType actualType = effectiveParameterType(index); + if (actualType != expectedType) { + throw new SQLException(callbackMethodName + " requires parameter " + index + " to be " + expectedType + + ", got " + actualType); + } + } + + private void ensurePrimitiveReturnType(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + DuckDBColumnType actualType = effectiveReturnType(); + if (actualType != expectedType) { + throw new SQLException(callbackMethodName + " requires return type " + expectedType + ", got " + + actualType); + } + } + + private static DuckDBConnection unwrapConnection(Connection connection) throws SQLException { + try { + return connection.unwrap(DuckDBConnection.class); + } catch (SQLException exception) { + throw new SQLException("Scalar function registration requires a DuckDB JDBC connection", exception); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java new file mode 100644 index 000000000..1bfc7f2d1 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -0,0 +1,35 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.nio.ByteBuffer; + +final class DuckDBScalarFunctionWrapper { + private final DuckDBScalarFunction function; + private final boolean propagateNulls; + + DuckDBScalarFunctionWrapper(DuckDBScalarFunction function, boolean propagateNulls) { + this.function = function; + this.propagateNulls = propagateNulls; + } + + public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, ByteBuffer outputVector) { + try { + DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk); + DuckDBWritableVector outputWriter = new DuckDBWritableVectorImpl(outputVector, inputReader.rowCount()); + DuckDBScalarContext context = new DuckDBScalarContext(inputReader, outputWriter, propagateNulls); + function.apply(context); + } catch (Throwable throwable) { + reportError(functionInfo, throwable); + } + } + + private static void reportError(ByteBuffer functionInfo, Throwable throwable) { + String message = throwable.getMessage(); + String className = throwable.getClass().getName(); + String formatted = + message == null || message.isEmpty() ? className : String.format("%s: %s", className, message); + String error = "Java scalar function threw exception: " + formatted; + DuckDBBindings.duckdb_scalar_function_set_error(functionInfo, error.getBytes(UTF_8)); + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarRow.java b/src/main/java/org/duckdb/DuckDBScalarRow.java new file mode 100644 index 000000000..7b51cf29e --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarRow.java @@ -0,0 +1,471 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; + +public final class DuckDBScalarRow { + private final DuckDBScalarContext context; + private final long rowIndex; + + DuckDBScalarRow(DuckDBScalarContext context, long rowIndex) { + this.context = context; + this.rowIndex = rowIndex; + } + + public long index() { + return rowIndex; + } + + public boolean isNull(int columnIndex) { + return input(columnIndex).isNull(rowIndex); + } + + public boolean getBoolean(int columnIndex) { + try { + return input(columnIndex).getBoolean(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("BOOLEAN", columnIndex, exception); + } + } + + public boolean getBoolean(int columnIndex, boolean defaultVal) { + try { + return input(columnIndex).getBoolean(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("BOOLEAN", columnIndex, exception); + } + } + + public byte getByte(int columnIndex) { + try { + return input(columnIndex).getByte(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("TINYINT", columnIndex, exception); + } + } + + public byte getByte(int columnIndex, byte defaultVal) { + try { + return input(columnIndex).getByte(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("TINYINT", columnIndex, exception); + } + } + + public short getShort(int columnIndex) { + try { + return input(columnIndex).getShort(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("SMALLINT", columnIndex, exception); + } + } + + public short getShort(int columnIndex, short defaultVal) { + try { + return input(columnIndex).getShort(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("SMALLINT", columnIndex, exception); + } + } + + public short getUint8(int columnIndex) { + try { + return input(columnIndex).getUint8(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("UTINYINT", columnIndex, exception); + } + } + + public short getUint8(int columnIndex, short defaultVal) { + try { + return input(columnIndex).getUint8(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("UTINYINT", columnIndex, exception); + } + } + + public int getUint16(int columnIndex) { + try { + return input(columnIndex).getUint16(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("USMALLINT", columnIndex, exception); + } + } + + public int getUint16(int columnIndex, int defaultVal) { + try { + return input(columnIndex).getUint16(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("USMALLINT", columnIndex, exception); + } + } + + public int getInt(int columnIndex) { + try { + return input(columnIndex).getInt(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("INTEGER", columnIndex, exception); + } + } + + public int getInt(int columnIndex, int defaultVal) { + try { + return input(columnIndex).getInt(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("INTEGER", columnIndex, exception); + } + } + + public long getUint32(int columnIndex) { + try { + return input(columnIndex).getUint32(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("UINTEGER", columnIndex, exception); + } + } + + public long getUint32(int columnIndex, long defaultVal) { + try { + return input(columnIndex).getUint32(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("UINTEGER", columnIndex, exception); + } + } + + public long getLong(int columnIndex) { + try { + return input(columnIndex).getLong(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("BIGINT", columnIndex, exception); + } + } + + public long getLong(int columnIndex, long defaultVal) { + try { + return input(columnIndex).getLong(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("BIGINT", columnIndex, exception); + } + } + + public BigInteger getHugeInt(int columnIndex) { + try { + return input(columnIndex).getHugeInt(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("HUGEINT", columnIndex, exception); + } + } + + public BigInteger getUHugeInt(int columnIndex) { + try { + return input(columnIndex).getUHugeInt(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("UHUGEINT", columnIndex, exception); + } + } + + public BigInteger getUint64(int columnIndex) { + try { + return input(columnIndex).getUint64(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("UBIGINT", columnIndex, exception); + } + } + + public float getFloat(int columnIndex) { + try { + return input(columnIndex).getFloat(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("FLOAT", columnIndex, exception); + } + } + + public float getFloat(int columnIndex, float defaultVal) { + try { + return input(columnIndex).getFloat(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("FLOAT", columnIndex, exception); + } + } + + public double getDouble(int columnIndex) { + try { + return input(columnIndex).getDouble(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("DOUBLE", columnIndex, exception); + } + } + + public double getDouble(int columnIndex, double defaultVal) { + try { + return input(columnIndex).getDouble(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("DOUBLE", columnIndex, exception); + } + } + + public LocalDate getLocalDate(int columnIndex) { + try { + return input(columnIndex).getLocalDate(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("DATE", columnIndex, exception); + } + } + + public java.sql.Date getDate(int columnIndex) { + try { + return input(columnIndex).getDate(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("DATE", columnIndex, exception); + } + } + + public LocalDateTime getLocalDateTime(int columnIndex) { + try { + return input(columnIndex).getLocalDateTime(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("TIMESTAMP", columnIndex, exception); + } + } + + public Timestamp getTimestamp(int columnIndex) { + try { + return input(columnIndex).getTimestamp(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("TIMESTAMP", columnIndex, exception); + } + } + + public OffsetDateTime getOffsetDateTime(int columnIndex) { + try { + return input(columnIndex).getOffsetDateTime(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("TIMESTAMP WITH TIME ZONE", columnIndex, exception); + } + } + + public BigDecimal getBigDecimal(int columnIndex) { + try { + return input(columnIndex).getBigDecimal(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("DECIMAL", columnIndex, exception); + } + } + + public String getString(int columnIndex) { + try { + return input(columnIndex).getString(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("VARCHAR", columnIndex, exception); + } + } + + public void setNull() { + try { + context.output().setNull(rowIndex); + } catch (DuckDBFunctionException exception) { + throw writeFailure("NULL", exception); + } + } + + public void setBoolean(boolean value) { + try { + context.output().setBoolean(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("BOOLEAN", exception); + } + } + + public void setByte(byte value) { + try { + context.output().setByte(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TINYINT", exception); + } + } + + public void setShort(short value) { + try { + context.output().setShort(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("SMALLINT", exception); + } + } + + public void setUint8(int value) { + try { + context.output().setUint8(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("UTINYINT", exception); + } + } + + public void setUint16(int value) { + try { + context.output().setUint16(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("USMALLINT", exception); + } + } + + public void setInt(int value) { + try { + context.output().setInt(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("INTEGER", exception); + } + } + + public void setUint32(long value) { + try { + context.output().setUint32(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("UINTEGER", exception); + } + } + + public void setLong(long value) { + try { + context.output().setLong(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("BIGINT", exception); + } + } + + public void setHugeInt(BigInteger value) { + try { + context.output().setHugeInt(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("HUGEINT", exception); + } + } + + public void setUHugeInt(BigInteger value) { + try { + context.output().setUHugeInt(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("UHUGEINT", exception); + } + } + + public void setUint64(BigInteger value) { + try { + context.output().setUint64(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("UBIGINT", exception); + } + } + + public void setFloat(float value) { + try { + context.output().setFloat(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("FLOAT", exception); + } + } + + public void setDouble(double value) { + try { + context.output().setDouble(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("DOUBLE", exception); + } + } + + public void setDate(LocalDate value) { + try { + context.output().setDate(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setDate(java.sql.Date value) { + try { + context.output().setDate(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setDate(java.util.Date value) { + try { + context.output().setDate(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setTimestamp(LocalDateTime value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(Timestamp value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(java.util.Date value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(LocalDate value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setOffsetDateTime(OffsetDateTime value) { + try { + context.output().setOffsetDateTime(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("TIMESTAMP WITH TIME ZONE", exception); + } + } + + public void setBigDecimal(BigDecimal value) { + try { + context.output().setBigDecimal(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("DECIMAL", exception); + } + } + + public void setString(String value) { + try { + context.output().setString(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("VARCHAR", exception); + } + } + + private DuckDBReadableVector input(int columnIndex) { + return context.inputUnchecked(columnIndex); + } + + private DuckDBFunctionException readFailure(String type, int columnIndex, DuckDBFunctionException exception) { + return new DuckDBFunctionException( + "Failed to read " + type + " from input column " + columnIndex + " at row " + rowIndex, exception); + } + + private DuckDBFunctionException writeFailure(String type, DuckDBFunctionException exception) { + return new DuckDBFunctionException("Failed to write " + type + " to output row " + rowIndex, exception); + } +} diff --git a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java new file mode 100644 index 000000000..a7768cc96 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java @@ -0,0 +1,91 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import java.sql.SQLException; + +final class DuckDBVectorTypeInfo { + final DuckDBColumnType columnType; + final DuckDBBindings.CAPIType capiType; + final DuckDBBindings.CAPIType storageType; + final int widthBytes; + final DuckDBColumnTypeMetaData decimalMeta; + + private DuckDBVectorTypeInfo(DuckDBColumnType columnType, DuckDBBindings.CAPIType capiType, + DuckDBBindings.CAPIType storageType, int widthBytes, + DuckDBColumnTypeMetaData decimalMeta) { + this.columnType = columnType; + this.capiType = capiType; + this.storageType = storageType; + this.widthBytes = widthBytes; + this.decimalMeta = decimalMeta; + } + + static DuckDBVectorTypeInfo fromVector(ByteBuffer vectorRef) throws SQLException { + ByteBuffer logicalType = duckdb_vector_get_column_type(vectorRef); + if (logicalType == null) { + throw new SQLException("Cannot read vector type"); + } + + try { + DuckDBBindings.CAPIType capiType = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_get_type_id(logicalType)); + switch (capiType) { + case DUCKDB_TYPE_BOOLEAN: + return new DuckDBVectorTypeInfo(DuckDBColumnType.BOOLEAN, capiType, capiType, 1, null); + case DUCKDB_TYPE_TINYINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TINYINT, capiType, capiType, 1, null); + case DUCKDB_TYPE_UTINYINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UTINYINT, capiType, capiType, 1, null); + case DUCKDB_TYPE_SMALLINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.SMALLINT, capiType, capiType, 2, null); + case DUCKDB_TYPE_USMALLINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.USMALLINT, capiType, capiType, 2, null); + case DUCKDB_TYPE_INTEGER: + return new DuckDBVectorTypeInfo(DuckDBColumnType.INTEGER, capiType, capiType, 4, null); + case DUCKDB_TYPE_UINTEGER: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UINTEGER, capiType, capiType, 4, null); + case DUCKDB_TYPE_BIGINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.BIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_HUGEINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.HUGEINT, capiType, capiType, 16, null); + case DUCKDB_TYPE_UBIGINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UBIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_UHUGEINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UHUGEINT, capiType, capiType, 16, null); + case DUCKDB_TYPE_FLOAT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.FLOAT, capiType, capiType, 4, null); + case DUCKDB_TYPE_DOUBLE: + return new DuckDBVectorTypeInfo(DuckDBColumnType.DOUBLE, capiType, capiType, 8, null); + case DUCKDB_TYPE_DATE: + return new DuckDBVectorTypeInfo(DuckDBColumnType.DATE, capiType, capiType, 4, null); + case DUCKDB_TYPE_TIMESTAMP_S: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_S, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_MS, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_NS, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, capiType, capiType, 8, null); + case DUCKDB_TYPE_VARCHAR: + return new DuckDBVectorTypeInfo(DuckDBColumnType.VARCHAR, capiType, capiType, 16, null); + case DUCKDB_TYPE_DECIMAL: { + DuckDBBindings.CAPIType internalType = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_decimal_internal_type(logicalType)); + DuckDBColumnTypeMetaData decimalMeta = new DuckDBColumnTypeMetaData( + (short) (internalType.widthBytes * 8), (short) duckdb_decimal_width(logicalType), + (short) duckdb_decimal_scale(logicalType)); + return new DuckDBVectorTypeInfo(DuckDBColumnType.DECIMAL, capiType, internalType, + (int) internalType.widthBytes, decimalMeta); + } + default: + throw new SQLException("Unsupported scalar function vector type: " + capiType); + } + } finally { + duckdb_destroy_logical_type(logicalType); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBWritableVector.java b/src/main/java/org/duckdb/DuckDBWritableVector.java new file mode 100644 index 000000000..1179e9084 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBWritableVector.java @@ -0,0 +1,116 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; + +/** + * Mutable scalar callback view over a DuckDB output vector. + * + *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value write + * errors. Invalid row indexes throw {@link IndexOutOfBoundsException}. + */ +public abstract class DuckDBWritableVector { + public abstract DuckDBColumnType getType(); + + public abstract long rowCount(); + + public abstract void addNull(); + + public abstract void setNull(long row); + + public abstract void addBoolean(boolean value); + + public abstract void setBoolean(long row, boolean value); + + public abstract void addByte(byte value); + + public abstract void setByte(long row, byte value); + + public abstract void addShort(short value); + + public abstract void setShort(long row, short value); + + public abstract void addUint8(int value); + + public abstract void setUint8(long row, int value); + + public abstract void addUint16(int value); + + public abstract void setUint16(long row, int value); + + public abstract void addInt(int value); + + public abstract void setInt(long row, int value); + + public abstract void addUint32(long value); + + public abstract void setUint32(long row, long value); + + public abstract void addLong(long value); + + public abstract void setLong(long row, long value); + + public abstract void addHugeInt(BigInteger value); + + public abstract void setHugeInt(long row, BigInteger value); + + public abstract void addUHugeInt(BigInteger value); + + public abstract void setUHugeInt(long row, BigInteger value); + + public abstract void addUint64(BigInteger value); + + public abstract void setUint64(long row, BigInteger value); + + public abstract void addFloat(float value); + + public abstract void setFloat(long row, float value); + + public abstract void addDouble(double value); + + public abstract void setDouble(long row, double value); + + public abstract void addDate(LocalDate value); + + public abstract void setDate(long row, LocalDate value); + + public abstract void addDate(java.sql.Date value); + + public abstract void setDate(long row, java.sql.Date value); + + public abstract void addDate(java.util.Date value); + + public abstract void setDate(long row, java.util.Date value); + + public abstract void addTimestamp(LocalDateTime value); + + public abstract void setTimestamp(long row, LocalDateTime value); + + public abstract void addTimestamp(Timestamp value); + + public abstract void setTimestamp(long row, Timestamp value); + + public abstract void addTimestamp(java.util.Date value); + + public abstract void setTimestamp(long row, java.util.Date value); + + public abstract void addTimestamp(LocalDate value); + + public abstract void setTimestamp(long row, LocalDate value); + + public abstract void addOffsetDateTime(OffsetDateTime value); + + public abstract void setOffsetDateTime(long row, OffsetDateTime value); + + public abstract void addBigDecimal(BigDecimal value); + + public abstract void setBigDecimal(long row, BigDecimal value); + + public abstract void addString(String value); + + public abstract void setString(long row, String value); +} diff --git a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java new file mode 100644 index 000000000..33887abb7 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java @@ -0,0 +1,765 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; + +final class DuckDBWritableVectorImpl extends DuckDBWritableVector { + private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private ByteBuffer validity; + private long appendIndex; + + DuckDBWritableVectorImpl(ByteBuffer vectorRef, long rowCount) { + if (vectorRef == null) { + throw new DuckDBFunctionException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Failed to resolve vector type info", exception); + } + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); + this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); + } + + @Override + public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + @Override + public long rowCount() { + return rowCount; + } + + @Override + public void addNull() { + setNull(nextAppendRow()); + } + + @Override + public void setNull(long row) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + ensureValidity(); + setRowValidity(row, false); + advanceAppendIndex(row); + } + + @Override + public void addBoolean(boolean value) { + setBoolean(nextAppendRow(), value); + } + + @Override + public void setBoolean(long row, boolean value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BOOLEAN); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.put(checkedRowIndex(row), value ? (byte) 1 : (byte) 0); + markValid(row); + } + + @Override + public void addByte(byte value) { + setByte(nextAppendRow(), value); + } + + @Override + public void setByte(long row, byte value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.TINYINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.put(checkedRowIndex(row), value); + markValid(row); + } + + @Override + public void addShort(short value) { + setShort(nextAppendRow(), value); + } + + @Override + public void setShort(long row, short value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.SMALLINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.putShort(checkedByteOffset(row, Short.BYTES), value); + markValid(row); + } + + @Override + public void addUint8(int value) { + setUint8(nextAppendRow(), value); + } + + @Override + public void setUint8(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UTINYINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UTINYINT", value, 0xFFL); + if (rangeError != null) { + throw new DuckDBFunctionException(rangeError); + } + data.put(checkedRowIndex(row), (byte) value); + markValid(row); + } + + @Override + public void addUint16(int value) { + setUint16(nextAppendRow(), value); + } + + @Override + public void setUint16(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.USMALLINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("USMALLINT", value, 0xFFFFL); + if (rangeError != null) { + throw new DuckDBFunctionException(rangeError); + } + data.putShort(checkedByteOffset(row, Short.BYTES), (short) value); + markValid(row); + } + + @Override + public void addInt(int value) { + setInt(nextAppendRow(), value); + } + + @Override + public void setInt(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.INTEGER); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), value); + markValid(row); + } + + @Override + public void addUint32(long value) { + setUint32(nextAppendRow(), value); + } + + @Override + public void setUint32(long row, long value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UINTEGER); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UINTEGER", value, 0xFFFFFFFFL); + if (rangeError != null) { + throw new DuckDBFunctionException(rangeError); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) value); + markValid(row); + } + + @Override + public void addLong(long value) { + setLong(nextAppendRow(), value); + } + + @Override + public void setLong(long row, long value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BIGINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.putLong(checkedByteOffset(row, Long.BYTES), value); + markValid(row); + } + + @Override + public void addHugeInt(BigInteger value) { + setHugeInt(nextAppendRow(), value); + } + + @Override + public void setHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.HUGEINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + DuckDBHugeInt hugeInt; + try { + hugeInt = new DuckDBHugeInt(value); + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Value out of range for HUGEINT: " + value, exception); + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, hugeInt.lower()); + data.putLong(offset + Long.BYTES, hugeInt.upper()); + markValid(row); + } + + @Override + public void addUHugeInt(BigInteger value) { + setUHugeInt(nextAppendRow(), value); + } + + @Override + public void setUHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UHUGEINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(DuckDBHugeInt.UHUGE_INT_MAX) > 0) { + throw new DuckDBFunctionException("Value out of range for UHUGEINT: " + value); + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, value.longValue()); + data.putLong(offset + Long.BYTES, value.shiftRight(Long.SIZE).longValue()); + markValid(row); + } + + @Override + public void addUint64(BigInteger value) { + setUint64(nextAppendRow(), value); + } + + @Override + public void setUint64(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UBIGINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { + throw new DuckDBFunctionException("Value out of range for UBIGINT: " + value); + } + data.putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); + markValid(row); + } + + @Override + public void addFloat(float value) { + setFloat(nextAppendRow(), value); + } + + @Override + public void setFloat(long row, float value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.FLOAT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.putFloat(checkedByteOffset(row, Float.BYTES), value); + markValid(row); + } + + @Override + public void addDouble(double value) { + setDouble(nextAppendRow(), value); + } + + @Override + public void setDouble(long row, double value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DOUBLE); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + data.putDouble(checkedByteOffset(row, Double.BYTES), value); + markValid(row); + } + + @Override + public void addDate(LocalDate value) { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, LocalDate value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DATE); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + long days = value.toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new DuckDBFunctionException("Value out of range for DATE: " + value); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) days); + markValid(row); + } + + @Override + public void addDate(java.sql.Date value) { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, java.sql.Date value) { + setDate(row, value == null ? null : value.toLocalDate()); + } + + @Override + public void addDate(java.util.Date value) { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, java.util.Date value) { + if (value == null) { + setNull(row); + return; + } + if (value instanceof java.sql.Date) { + setDate(row, (java.sql.Date) value); + return; + } + LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate(); + setDate(row, localDate); + } + + @Override + public void addTimestamp(LocalDateTime value) { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, LocalDateTime value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); + markValid(row); + } + + @Override + public void addTimestamp(Timestamp value) { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, Timestamp value) { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(value.toInstant())); + markValid(row); + return; + } + setTimestamp(row, value.toLocalDateTime()); + } + + @Override + public void addTimestamp(java.util.Date value) { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, java.util.Date value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value instanceof Timestamp) { + setTimestamp(row, (Timestamp) value); + return; + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); + markValid(row); + } + + @Override + public void addTimestamp(LocalDate value) { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, LocalDate value) { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); + data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); + markValid(row); + return; + } + setTimestamp(row, value.atStartOfDay()); + } + + @Override + public void addOffsetDateTime(OffsetDateTime value) { + setOffsetDateTime(nextAppendRow(), value); + } + + @Override + public void setOffsetDateTime(long row, OffsetDateTime value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(true); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.putLong( + checkedByteOffset(row, Long.BYTES), + DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); + markValid(row); + } + + @Override + public void addBigDecimal(BigDecimal value) { + setBigDecimal(nextAppendRow(), value); + } + + @Override + public void setBigDecimal(long row, BigDecimal value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DECIMAL); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + BigDecimal scaled; + try { + scaled = value.setScale(typeInfo.decimalMeta.scale); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + if (scaled.precision() > typeInfo.decimalMeta.width) { + throw decimalOutOfRange(value); + } + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + try { + data.putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_INTEGER: + try { + data.putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_BIGINT: + try { + data.putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_HUGEINT: { + BigInteger unscaled = scaled.unscaledValue(); + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, unscaled.longValue()); + data.putLong(offset + Long.BYTES, unscaled.shiftRight(Long.SIZE).longValue()); + break; + } + default: + throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + markValid(row); + } + + @Override + public void addString(String value) { + setString(nextAppendRow(), value); + } + + @Override + public void setString(long row, String value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.VARCHAR); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8)); + markValid(row); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void ensureValidity() { + if (validity != null) { + return; + } + duckdb_vector_ensure_validity_writable(vectorRef); + validity = duckdb_vector_get_validity(vectorRef, rowCount); + if (validity == null) { + throw new DuckDBFunctionException("Cannot initialize vector validity"); + } + validity = validity.order(NATIVE_ORDER); + } + + private void markValid(long row) { + if (validity == null) { + advanceAppendIndex(row); + return; + } + setRowValidity(row, true); + advanceAppendIndex(row); + } + + private void setRowValidity(long row, boolean valid) { + int entryOffset = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); + long bitIndex = row % Long.SIZE; + long mask = 1L << bitIndex; + long entry = validity.getLong(entryOffset); + if (valid) { + entry |= mask; + } else { + entry &= ~mask; + } + validity.putLong(entryOffset, entry); + } + + private String typeMismatchMessage(DuckDBColumnType expected) { + if (typeInfo.columnType != expected) { + return "Expected vector type " + expected + ", found " + typeInfo.columnType; + } + return null; + } + + private String rowIndexErrorMessage(long row) { + if (row < 0 || row >= rowCount) { + return "Row index out of bounds: " + row; + } + return null; + } + + private String timestampTypeMismatchMessage(boolean requireTimezone) { + if (requireTimezone) { + if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + return "Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType; + } + return null; + } + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return null; + default: + return "Expected vector type TIMESTAMP*, found " + typeInfo.columnType; + } + } + + private long encodeLocalDateTime(LocalDateTime value) { + Instant instant; + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + instant = value.atZone(ZoneId.systemDefault()).toInstant(); + } else { + instant = value.toInstant(ZoneOffset.UTC); + } + return encodeInstant(instant); + } + + private long encodeJavaUtilDate(java.util.Date value) { + return encodeInstant(Instant.ofEpochMilli(value.getTime())); + } + + private long encodeInstant(Instant instant) { + long epochSeconds = instant.getEpochSecond(); + int nano = instant.getNano(); + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return epochSeconds; + case DUCKDB_TYPE_TIMESTAMP_MS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L); + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_TZ: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L); + case DUCKDB_TYPE_TIMESTAMP_NS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); + default: + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private static String unsignedRangeErrorMessage(String typeName, long value, long maxValue) { + if (value < 0 || value > maxValue) { + return "Value out of range for " + typeName + ": " + value; + } + return null; + } + + private DuckDBFunctionException decimalOutOfRange(BigDecimal value) { + return new DuckDBFunctionException("Value out of range for " + decimalTypeName() + ": " + value); + } + + private DuckDBFunctionException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { + DuckDBFunctionException exception = decimalOutOfRange(value); + exception.initCause(cause); + return exception; + } + + private String decimalTypeName() { + return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")"; + } + + private int checkedRowIndex(long row) { + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } + + private void advanceAppendIndex(long row) { + appendIndex = Math.max(appendIndex, Math.addExact(row, 1)); + } + + private long nextAppendRow() { + if (appendIndex >= rowCount) { + throw new IndexOutOfBoundsException("Append index out of bounds: " + appendIndex); + } + return appendIndex; + } +} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 2c97dfee7..8bdf0ad3f 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -6,7 +6,9 @@ import static org.duckdb.TestDuckDBJDBC.JDBC_URL; import static org.duckdb.test.Assertions.*; +import java.math.BigInteger; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.sql.*; import java.util.Arrays; @@ -20,6 +22,65 @@ public static void test_bindings_vector_size() throws Exception { assertTrue(size > 0); } + public static void test_bindings_vector_row_index_stream() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer inputVec = duckdb_create_vector(lt); + ByteBuffer outputVec = duckdb_create_vector(lt); + + DuckDBWritableVector input = new DuckDBWritableVectorImpl(inputVec, 3); + input.setInt(0, 1); + input.setInt(1, 41); + input.setInt(2, -5); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(inputVec, 3); + DuckDBWritableVector output = new DuckDBWritableVectorImpl(outputVec, 3); + readable.rowIndexStream().forEachOrdered(row -> { output.setInt(row, readable.getInt(row) + 1); }); + + DuckDBReadableVector result = new DuckDBReadableVectorImpl(outputVec, 3); + assertEquals(result.getInt(0), 2); + assertEquals(result.getInt(1), 42); + assertEquals(result.getInt(2), -4); + + duckdb_destroy_vector(outputVec); + duckdb_destroy_vector(inputVec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_writable_vector_append_after_indexed_write() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 3); + writable.setNull(0); + writable.setInt(1, 41); + writable.addInt(42); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 3); + assertTrue(readable.isNull(0)); + assertEquals(readable.getInt(1), 41); + assertEquals(readable.getInt(2), 42); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_writable_vector_failed_append_does_not_advance() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 2); + assertThrows(() -> { writable.addString("boom"); }, DuckDBFunctionException.class); + writable.addInt(7); + writable.addInt(8); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 2); + assertEquals(readable.getInt(0), 7); + assertEquals(readable.getInt(1), 8); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_logical_type() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); assertNotNull(lt); @@ -42,6 +103,26 @@ public static void test_bindings_logical_type() throws Exception { assertThrows(() -> { duckdb_destroy_logical_type(null); }, SQLException.class); } + public static void test_bindings_parse_logical_type() throws Exception { + try (DuckDBLogicalType integerType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + assertNotNull(integerType); + assertEquals(DUCKDB_TYPE_INTEGER.typeId, duckdb_get_type_id(integerType.logicalTypeRef())); + } + + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(18, 3)) { + assertNotNull(decimalType); + ByteBuffer decimalRef = decimalType.logicalTypeRef(); + assertEquals(DUCKDB_TYPE_DECIMAL.typeId, duckdb_get_type_id(decimalRef)); + assertEquals(18, duckdb_decimal_width(decimalRef)); + assertEquals(3, duckdb_decimal_scale(decimalRef)); + assertEquals(DUCKDB_TYPE_BIGINT.typeId, duckdb_decimal_internal_type(decimalRef)); + } + + assertThrows(() -> { DuckDBLogicalType.of(null); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.decimal(39, 0); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.decimal(10, 11); }, SQLException.class); + } + public static void test_bindings_vector_create() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer vec = duckdb_create_vector(lt); @@ -108,6 +189,132 @@ public static void test_bindings_vector_strings() throws Exception { duckdb_destroy_logical_type(lt); } + public static void test_bindings_vector_get_string() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + long rowCount = duckdb_vector_size(); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + writable.setNull(0); + writable.setString(1, "duckdb"); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + assertNull(readable.getString(0)); + assertEquals(readable.getString(1), "duckdb"); + assertThrows(() -> { readable.getString(rowCount); }, IndexOutOfBoundsException.class); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_vector_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + int expected = 0x01020304; + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + writable.setInt(0, expected); + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Integer.BYTES); + assertEquals(rawData.order(ByteOrder.nativeOrder()).getInt(0), expected); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + assertEquals(readable.getInt(0), expected); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_writable_vector_stack_trace_origin() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + + try { + writable.setInt(0, 42); + fail("Expected setInt to reject VARCHAR vector"); + } catch (DuckDBFunctionException exception) { + assertTrue(exception.getMessage().contains("Expected vector type INTEGER, found VARCHAR")); + assertEquals(exception.getStackTrace()[0].getMethodName(), "setInt"); + } + + try { + writable.setString(rowCount, "boom"); + fail("Expected setString to reject out-of-bounds row"); + } catch (IndexOutOfBoundsException exception) { + assertTrue(exception.getMessage().contains("Row index out of bounds")); + assertEquals(exception.getStackTrace()[0].getMethodName(), "setString"); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_UBIGINT.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + assertTrue(rowCount >= 4); + BigInteger[] values = + new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), + new BigInteger("18446744073709551615")}; + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + writable.setUint64(i, values[i]); + } + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Long.BYTES); + ByteBuffer nativeData = rawData.order(ByteOrder.nativeOrder()); + for (int i = 0; i < values.length; i++) { + assertEquals(nativeData.getLong(i * Long.BYTES), values[i].longValue()); + } + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + assertEquals(readable.getUint64(i), values[i]); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_vector_uhugeint_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_UHUGEINT.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + assertTrue(rowCount >= 4); + BigInteger[] values = + new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), + new BigInteger("340282366920938463463374607431768211455")}; + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + writable.setUHugeInt(i, values[i]); + } + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Long.BYTES * 2); + ByteBuffer nativeData = rawData.order(ByteOrder.nativeOrder()); + for (int i = 0; i < values.length; i++) { + int offset = i * Long.BYTES * 2; + assertEquals(nativeData.getLong(offset), values[i].longValue()); + assertEquals(nativeData.getLong(offset + Long.BYTES), values[i].shiftRight(Long.SIZE).longValue()); + } + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + assertEquals(readable.getUHugeInt(i), values[i]); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_vector_validity() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); @@ -211,6 +418,53 @@ public static void test_bindings_validity() throws Exception { duckdb_destroy_logical_type(lt); } + public static void test_bindings_writable_vector_validity_word_boundaries() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + long rowCount = 70; + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + long[] boundaryRows = new long[] {63, 64, 65, rowCount - 1}; + long[] sentinelRows = new long[] {62, 66, 67, 68}; + + for (long row : boundaryRows) { + writable.setInt(row, (int) row); + } + for (long row : sentinelRows) { + writable.setInt(row, (int) (row + 10_000)); + } + + for (long row : boundaryRows) { + writable.setNull(row); + } + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + for (long row : boundaryRows) { + assertTrue(readable.isNull(row)); + } + for (long row : sentinelRows) { + assertFalse(readable.isNull(row)); + assertEquals(readable.getInt(row), (int) (row + 10_000)); + } + + for (long row : boundaryRows) { + writable.setInt(row, (int) (row + 1000)); + } + + DuckDBReadableVector revalidated = new DuckDBReadableVectorImpl(vec, rowCount); + for (long row : boundaryRows) { + assertFalse(revalidated.isNull(row)); + assertEquals(revalidated.getInt(row), (int) (row + 1000)); + } + for (long row : sentinelRows) { + assertFalse(revalidated.isNull(row)); + assertEquals(revalidated.getInt(row), (int) (row + 10_000)); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_data_chunk() throws Exception { ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer varcharType = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); @@ -343,6 +597,13 @@ public static void test_bindings_decimal_type() throws Exception { } } + public static void test_bindings_decimal_type_validation() throws Exception { + assertThrows(() -> { duckdb_create_decimal_type(0, 0); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(39, 0); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(10, -1); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(10, 11); }, SQLException.class); + } + public static void test_bindings_enum_type() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 3769121a7..88051efdb 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -2243,12 +2243,13 @@ public static void main(String[] args) throws Exception { Class clazz = Class.forName("org.duckdb." + arg1); statusCode = runTests(new String[0], clazz); } else { - statusCode = runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, - TestAppenderCollection2D.class, TestAppenderComposite.class, - TestSingleValueAppender.class, TestBatch.class, TestBindings.class, TestClosure.class, - TestExtensionTypes.class, TestMetadata.class, TestNoLib.class, /* TestSpatial.class,*/ - TestParameterMetadata.class, TestPrepare.class, TestResults.class, - TestSessionInit.class, TestTimestamp.class, TestVariant.class); + statusCode = + runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, + TestAppenderCollection2D.class, TestAppenderComposite.class, TestSingleValueAppender.class, + TestBatch.class, TestBindings.class, TestClosure.class, TestExtensionTypes.class, + TestMetadata.class, TestNoLib.class, TestScalarFunctions.class, + /* TestSpatial.class,*/ TestParameterMetadata.class, TestPrepare.class, TestResults.class, + TestSessionInit.class, TestTimestamp.class, TestVariant.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java new file mode 100644 index 000000000..8c4fc856e --- /dev/null +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -0,0 +1,2186 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER; +import static org.duckdb.TestDuckDBJDBC.JDBC_URL; +import static org.duckdb.test.Assertions.*; + +import java.lang.reflect.Proxy; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +public class TestScalarFunctions { + private interface ResultSetVerifier { + void verify(ResultSet rs) throws Exception; + } + + private static int sumNonNullIntColumns(DuckDBScalarContext ctx, DuckDBScalarRow row) { + int sum = 0; + for (int columnIndex = 0; columnIndex < ctx.columnCount(); columnIndex++) { + if (!row.isNull(columnIndex)) { + sum += row.getInt(columnIndex); + } + } + return sum; + } + + public static void test_bindings_scalar_function() throws Exception { + ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer scalarFunction = duckdb_create_scalar_function(); + assertNotNull(scalarFunction); + + duckdb_scalar_function_set_name(scalarFunction, "binding_scalar_fn".getBytes(UTF_8)); + duckdb_scalar_function_add_parameter(scalarFunction, intType); + duckdb_scalar_function_set_return_type(scalarFunction, intType); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class)) { + assertEquals(duckdb_register_scalar_function(conn.connRef, scalarFunction), 1); + + assertThrows(() -> { duckdb_register_scalar_function(null, scalarFunction); }, SQLException.class); + assertThrows(() -> { duckdb_register_scalar_function(conn.connRef, null); }, SQLException.class); + } + + duckdb_destroy_scalar_function(scalarFunction); + duckdb_destroy_logical_type(intType); + + assertThrows(() -> { duckdb_destroy_scalar_function(null); }, SQLException.class); + assertThrows(() -> { duckdb_scalar_function_set_name(null, "x".getBytes(UTF_8)); }, SQLException.class); + } + + public static void test_register_scalar_function() throws Exception { + test_register_scalar_function_integer(); + } + + public static void test_register_scalar_function_builder() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder") + .withParameter(intType) + .withReturnType(intType) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); + }) + .register(conn); + assertEquals(function.name(), "java_add_int_builder"); + assertEquals(function.parameterTypes().size(), 1); + assertEquals(function.parameterTypes().get(0), intType); + assertEquals(function.returnType(), intType); + assertEquals(function.varArgType(), null); + assertEquals(function.isVolatile(), false); + assertEquals(function.hasSpecialHandling(), false); + assertEquals(function.propagateNulls(), false); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_builder(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_connection_without_unwrap() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_connection") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> null != x ? x + 1 : null) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_connection(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_returns_detached_metadata() throws Exception { + DuckDBRegisteredFunction function; + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + function = builder.withName("java_add_int_detached") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> null != x ? x + 1 : null) + .register(conn); + + String message = + assertThrows(() -> { builder.withName("java_add_int_detached_again"); }, SQLException.class); + assertTrue(message.contains("already finalized")); + + assertEquals(function.name(), "java_add_int_detached"); + assertEquals(function.parameterColumnTypes().size(), 1); + assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); + assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); + assertNotNull(function.function()); + assertEquals(function.functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); + assertTrue(function.isScalar()); + assertEquals(function.propagateNulls(), false); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_detached(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_registry_records_registered_functions() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_registry_recorded") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertEquals(registeredFunctions.size(), 1); + assertEquals(registeredFunctions.get(0), function); + assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); + assertTrue(registeredFunctions.get(0).isScalar()); + + try (ResultSet rs = stmt.executeQuery("SELECT java_registry_recorded(41)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_is_read_only() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_read_only") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertThrows(() -> { registeredFunctions.add(null); }, UnsupportedOperationException.class); + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_clear_only_clears_java_registry() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_clear_only") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + assertEquals(DuckDBDriver.registeredFunctions().size(), 1); + DuckDBDriver.clearFunctionsRegistry(); + assertEquals(DuckDBDriver.registeredFunctions().size(), 0); + + try (ResultSet rs = stmt.executeQuery("SELECT java_registry_clear_only(41)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_allows_duplicate_names() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + Path tempDir = Files.createTempDirectory("duckdb-registry"); + Path dbPathA = tempDir.resolve("registry-a.db"); + Path dbPathB = tempDir.resolve("registry-b.db"); + String urlA = "jdbc:duckdb:" + dbPathA.toAbsolutePath(); + String urlB = "jdbc:duckdb:" + dbPathB.toAbsolutePath(); + + try (Connection connA = DriverManager.getConnection(urlA); + Connection connB = DriverManager.getConnection(urlB)) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_duplicate_name") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(connA); + DuckDBFunctions.scalarFunction() + .withName("java_registry_duplicate_name") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 2) + .register(connB); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertEquals(registeredFunctions.size(), 2); + assertEquals(registeredFunctions.get(0).name(), "java_registry_duplicate_name"); + assertEquals(registeredFunctions.get(1).name(), "java_registry_duplicate_name"); + } finally { + DuckDBDriver.clearFunctionsRegistry(); + Files.deleteIfExists(dbPathA); + Files.deleteIfExists(dbPathB); + Files.deleteIfExists(tempDir); + } + } + + public static void test_register_scalar_function_builder_rejects_non_duckdb_connection() throws Exception { + Connection connection = (Connection) Proxy.newProxyInstance( + Connection.class.getClassLoader(), new Class[] {Connection.class}, (proxy, method, args) -> { + switch (method.getName()) { + case "unwrap": + throw new SQLException("not a DuckDB connection"); + case "isWrapperFor": + return false; + case "toString": + return "invalid-connection"; + case "hashCode": + return System.identityHashCode(proxy); + case "equals": + return proxy == args[0]; + default: + throw new UnsupportedOperationException(method.getName()); + } + }); + + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_connection") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1); + + String message = assertThrows(() -> { builder.register(connection); }, SQLException.class); + assertTrue(message.contains("requires a DuckDB JDBC connection")); + } + } + + public static void test_register_scalar_function_builder_varargs_and_flags() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs_builder") + .withParameter(intType) + .withVarArgs(intType) + .withReturnType(intType) + .withVolatile() + .withSpecialHandling() + .withVectorizedFunction( + ctx -> { ctx.stream().forEachOrdered(row -> { row.setInt(sumNonNullIntColumns(ctx, row)); }); }) + .register(conn); + assertEquals(function.varArgType(), intType); + assertEquals(function.isVolatile(), true); + assertEquals(function.hasSpecialHandling(), true); + assertEquals(function.propagateNulls(), false); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_sum_varargs_builder(1, 2, 3), java_sum_varargs_builder(5)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 6); + assertEquals(rs.getObject(2, Integer.class), 5); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_column_type_overloads() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder_col_type") + .withParameter(DuckDBColumnType.INTEGER) + .withReturnType(DuckDBColumnType.INTEGER) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); + }) + .register(conn); + assertEquals(function.parameterColumnTypes().size(), 1); + assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); + assertEquals(function.parameterTypes().get(0), null); + assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); + assertEquals(function.returnType(), null); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_builder_col_type(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_parameters_class_helper() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_parameters") + .withParameters(Integer.class, Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer left, Integer right) -> left != null && right != null ? left + right : null) + .register(conn); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_with_parameters(a, b) FROM (VALUES (1, 2), (NULL, 1), (20, 22)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> null != x ? x + 1 : null) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_function(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function_propagate_nulls_false() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_function_nullable") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x == null ? 99 : x + 1) + .register(conn); + assertEquals(function.propagateNulls(), false); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_function_nullable(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 99); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_bifunction() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_bifunction") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x, Integer y) -> null != x && null != y ? x + y : null) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_bifunction(a, b) FROM (VALUES (1, 2), (NULL, 2), (39, 3), (5, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_bifunction_propagate_nulls_false() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_bifunction_nullable") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction( + (Integer left, Integer right) -> (left == null ? 0 : left) + (right == null ? 0 : right)) + .register(conn); + assertEquals(function.propagateNulls(), false); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_bifunction_nullable(a, b) FROM (VALUES (1, 2), (NULL, 2), (39, NULL), (NULL, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 39); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 0); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_int_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_with_int_function(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_int_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_binary_function") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_with_int_binary_function(a, b) " + + "FROM (VALUES (1, 2), (NULL, 2), (39, 3), (5, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_double_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_function") + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction(x -> x + 0.5d) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_double_with_double_function(v) FROM (VALUES (41.5), (NULL), (-2.5)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -2.0d); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_double_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_binary_function") + .withParameter(Double.class) + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_double_with_double_binary_function(a, b) FROM " + + "(VALUES (1.0, 2.0), (NULL, 2.0), (39.5, 2.5), (5.0, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 3.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_long_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_function") + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction(x -> x + 3) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_long_with_long_function(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_long_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_binary_function") + .withParameter(Long.class) + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery("SELECT java_add_long_with_long_binary_function(a, b) " + + "FROM (VALUES (1::BIGINT, 2::BIGINT), (NULL, 2::BIGINT), " + + "(39::BIGINT, 3::BIGINT), (5::BIGINT, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 3L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function_class_cast_error() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_invalid_cast_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((String value) -> value.length()) + .register(conn); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_invalid_cast_function(1)"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("ClassCastException")); + } + } + + public static void test_register_scalar_function_builder_java_supplier() throws Exception { + assertNullaryJavaFunction("java_constant_supplier", Integer.class, + () -> 42, "SELECT java_constant_supplier() FROM range(3)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_supplier_null_value() throws Exception { + assertNullaryJavaFunction("java_null_supplier", String.class, + () -> null, "SELECT java_null_supplier() FROM range(2)", rs -> { + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_supplier_class_cast_error() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_invalid_supplier_cast") + .withReturnType(Integer.class) + .withFunction(() -> "not_an_integer") + .register(conn); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_invalid_supplier_cast()"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("ClassCastException")); + } + } + + public static void test_register_scalar_function_builder_java_supplier_rejects_parameters() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_supplier_arity").withParameter(Integer.class).withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withFunction(() -> 1); }, SQLException.class); + assertTrue(message.contains("Supplier callback requires zero declared parameters")); + } + } + + public static void test_register_scalar_function_builder_java_supplier_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_supplier_varargs").withReturnType(Integer.class).withVarArgs(intType); + String message = assertThrows(() -> { builder.withFunction(() -> 1); }, SQLException.class); + assertTrue(message.contains("Supplier callback does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_java_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_function_varargs") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = assertThrows(() -> { builder.withFunction((Integer x) -> x + 1); }, SQLException.class); + assertTrue(message.contains("Function callback does not support varargs")); + assertTrue(message.contains("withVarArgsFunction")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_int_function_varargs") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withIntFunction does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_wrong_types() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_int_function_type") + .withParameter(Double.class) + .withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withIntFunction requires parameter 0 to be INTEGER")); + } + } + + public static void test_register_scalar_function_builder_java_bifunction_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_bifunction_varargs") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = + assertThrows(() -> { builder.withFunction((Integer x, Integer y) -> x + y); }, SQLException.class); + assertTrue(message.contains("BiFunction callback does not support varargs")); + assertTrue(message.contains("withVarArgsFunction")); + } + } + + public static void test_register_scalar_function_builder_with_double_function_rejects_wrong_types() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_double_function_type") + .withParameter(Integer.class) + .withReturnType(Double.class); + String message = assertThrows(() -> { builder.withDoubleFunction(x -> x + 0.5d); }, SQLException.class); + assertTrue(message.contains("withDoubleFunction requires parameter 0 to be DOUBLE")); + } + } + + public static void test_register_scalar_function_builder_with_long_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { + builder.withName("java_invalid_long_function_varargs") + .withParameter(Long.class) + .withReturnType(Long.class) + .withVarArgs(bigintType); + String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withLongFunction does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_with_long_function_rejects_wrong_types() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_long_function_type").withParameter(Integer.class).withReturnType(Long.class); + String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withLongFunction requires parameter 0 to be BIGINT")); + } + } + + public static void test_register_scalar_function_builder_java_varargs_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs_function") + .withParameter(Integer.class) + .withVarArgs(intType) + .withReturnType(Integer.class) + .withVarArgsFunction(args -> { + int sum = 0; + for (Object arg : args) { + sum += (Integer) arg; + } + return sum; + }) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_sum_varargs_function(1, 2, 3), java_sum_varargs_function(5), " + + "java_sum_varargs_function(NULL, 2), java_sum_varargs_function(2, NULL)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 6); + assertEquals(rs.getObject(2, Integer.class), 5); + assertEquals(rs.getObject(3), null); + assertEquals(rs.getObject(4), null); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_varargs_function_requires_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_varargs_function") + .withParameter(Integer.class) + .withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withVarArgsFunction(args -> 0); }, SQLException.class); + assertTrue(message.contains("requires withVarArgs")); + } + } + + public static void test_register_scalar_function_builder_java_function_supported_class_types() throws Exception { + Function notBoolean = value -> null != value ? !value : null; + Function addTinyInt = value -> null != value ? (byte) (value + 1) : null; + Function addBigInt = value -> null != value ? value + 3 : null; + Function addDouble = value -> null != value ? value + 0.5 : null; + Function suffixString = value -> null != value ? value + "_ok" : null; + Function addDate = value -> null != value ? value.plusDays(2) : null; + Function addTimestamp = value -> null != value ? value.plusMinutes(30) : null; + Function addTimestampTz = value -> null != value ? value.plusMinutes(5) : null; + assertUnaryJavaFunction("java_not_bool_function", Boolean.class, Boolean.class, notBoolean, + "SELECT java_not_bool_function(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_tinyint_function", Byte.class, Byte.class, addTinyInt, + "SELECT java_add_tinyint_function(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_bigint_function", Long.class, Long.class, addBigInt, + "SELECT java_add_bigint_function(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_double_function", Double.class, Double.class, addDouble, + "SELECT java_add_double_function(v) FROM (VALUES (41.5::DOUBLE), (NULL), (-2.5::DOUBLE)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -2.0); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction("java_suffix_varchar_function", String.class, String.class, suffixString, + "SELECT java_suffix_varchar_function(v) FROM (VALUES ('duck'), (NULL), ('db')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "db_ok"); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_date_function", LocalDate.class, LocalDate.class, addDate, + "SELECT java_add_date_function(v) FROM (VALUES (DATE '2024-07-21'), (NULL), (DATE '2024-07-30')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 23)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 8, 1)); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_timestamp_function", LocalDateTime.class, LocalDateTime.class, addTimestamp, + "SELECT java_add_timestamp_function(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(1970, 1, 1, 0, 15, 0)); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_timestamptz_function", OffsetDateTime.class, OffsetDateTime.class, addTimestampTz, + "SELECT java_add_timestamptz_function(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_supported_unsigned_types() throws Exception { + Function addUTinyInt = value -> null != value ? (short) (value + 1) : null; + Function addUSmallInt = value -> null != value ? value + 2 : null; + Function addUInteger = value -> null != value ? value + 3 : null; + Function addUBigInt = value -> null != value ? value.add(BigInteger.ONE) : null; + assertUnaryJavaFunction( + "java_add_utinyint_function", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, addUTinyInt, + "SELECT java_add_utinyint_function(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 255); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_usmallint_function", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, addUSmallInt, + "SELECT java_add_usmallint_function(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 65535); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_uinteger_function", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, addUInteger, + "SELECT java_add_uinteger_function(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 4294967295L); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_ubigint_function", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, addUBigInt, + "SELECT java_add_ubigint_function(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_hugeint_class_mapping() throws Exception { + Function addHugeInt = value -> null != value ? value.add(BigInteger.ONE) : null; + assertUnaryJavaFunction("java_add_hugeint_function", BigInteger.class, BigInteger.class, addHugeInt, + "SELECT java_add_hugeint_function(v) FROM (VALUES (CAST('41' AS HUGEINT)), (NULL), " + + "(CAST('170141183460469231731687303715884105726' AS HUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("170141183460469231731687303715884105727")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_decimal() throws Exception { + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + Function addDecimal = + value -> null != value ? value.add(new BigDecimal("1.25")) : null; + assertUnaryJavaFunction("java_add_decimal_function", decimalType, decimalType, addDecimal, + "SELECT java_add_decimal_function(v) FROM (VALUES (CAST(40.75 AS DECIMAL(10,2))), " + + "(NULL), (CAST(-1.25 AS DECIMAL(10,2)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("42.00")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("0.00")); + assertFalse(rs.next()); + }); + } + } + + public static void test_register_scalar_function_builder_java_function_sql_date_class_mapping() throws Exception { + Function addSqlDate = + value -> null != value ? java.sql.Date.valueOf(value.toLocalDate().plusDays(1)) : null; + assertUnaryJavaFunction( + "java_add_sql_date_function", java.sql.Date.class, java.sql.Date.class, addSqlDate, + "SELECT java_add_sql_date_function(v) FROM (VALUES (DATE '2024-07-21'), (NULL), (DATE '2024-07-30')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Date.class), java.sql.Date.valueOf("2024-07-22")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Date.class), java.sql.Date.valueOf("2024-07-31")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_sql_timestamp_class_mapping() + throws Exception { + Function addSqlTimestamp = + value -> null != value ? java.sql.Timestamp.valueOf(value.toLocalDateTime().plusSeconds(1)) : null; + assertUnaryJavaFunction("java_add_sql_timestamp_function", java.sql.Timestamp.class, java.sql.Timestamp.class, + addSqlTimestamp, + "SELECT java_add_sql_timestamp_function(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Timestamp.class), + java.sql.Timestamp.valueOf("2024-07-21 12:34:57.123456")); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_java_util_date_class_mapping() + throws Exception { + Function addUtilDate = + value -> null != value ? new java.util.Date(value.getTime() + 1000L) : null; + assertUnaryJavaFunction("java_add_java_util_date_function", java.util.Date.class, java.util.Date.class, + addUtilDate, + "SELECT java_add_java_util_date_function(v) = date_trunc('millisecond', v) + " + + "INTERVAL 1 SECOND FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_bifunction_supported_types() throws Exception { + BiFunction concatUnderscore = + (left, right) -> left != null && right != null ? left + "_" + right : null; + BiFunction sumDouble = + (left, right) -> left != null && right != null ? Double.sum(left, right) : null; + assertBinaryJavaFunction( + "java_concat_varchar_bifunction", String.class, String.class, String.class, concatUnderscore, + "SELECT java_concat_varchar_bifunction(a, b) FROM (VALUES ('duck', 'db'), (NULL, 'x'), ('a', NULL)) t(a, b)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_db"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + + assertBinaryJavaFunction( + "java_add_double_bifunction", Double.class, Double.class, Double.class, sumDouble, + "SELECT java_add_double_bifunction(a, b) FROM (VALUES (10.5::DOUBLE, 31.5::DOUBLE), (NULL, 2::DOUBLE), (2::DOUBLE, NULL)) t(a, b)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_typed_logical_type() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_typed") + .withParameter(intType) + .withReturnType(intType) + .withVectorizedFunction( + ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_typed(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_parallel() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { + stmt.execute("PRAGMA threads=4"); + DuckDBFunctions.scalarFunction() + .withName("java_add_one_bigint") + .withParameter(bigintType) + .withReturnType(bigintType) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + DuckDBReadableVector in = ctx.input(0); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + out.setLong(row, in.getLong(row) + 1); + } + }) + .register(conn); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(java_add_one_bigint(i)) FROM range(1000000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500000500000L); + assertFalse(rs.wasNull()); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_context_row_stream_int() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_row_stream") + .withParameter(intType) + .withReturnType(intType) + .withVectorizedFunction( + ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_row_stream(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_context_row_stream_double() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType doubleType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { + DuckDBFunctions.scalarFunction() + .withName("java_add_double_row_stream") + .withParameter(doubleType) + .withReturnType(doubleType) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); + }) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_double_row_stream(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_primitive_nulls_handling() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_primitive_nulls_handling") + .withParameter(DuckDBColumnType.BOOLEAN) + .withParameter(DuckDBColumnType.TINYINT) + .withParameter(DuckDBColumnType.UTINYINT) + .withParameter(DuckDBColumnType.SMALLINT) + .withParameter(DuckDBColumnType.USMALLINT) + .withParameter(DuckDBColumnType.INTEGER) + .withParameter(DuckDBColumnType.UINTEGER) + .withParameter(DuckDBColumnType.BIGINT) + .withParameter(DuckDBColumnType.FLOAT) + .withParameter(DuckDBColumnType.DOUBLE) + .withReturnType(DuckDBColumnType.VARCHAR) + .withSpecialHandling() + .withVectorizedFunction(ctx -> { + assertFalse(ctx.nullsPropagated()); + ctx.stream().forEachOrdered(row -> { + try { + DuckDBReadableVector booleanVector = ctx.input(0); + DuckDBReadableVector intVector = ctx.input(5); + assertThrows( + () -> { booleanVector.getBoolean(row.index()); }, DuckDBFunctionException.class); + assertTrue(booleanVector.getBoolean(row.index(), true)); + assertThrows(() -> { intVector.getInt(row.index()); }, DuckDBFunctionException.class); + assertEquals(intVector.getInt(row.index(), 42), 42); + + assertThrows(() -> { row.getBoolean(0); }, DuckDBFunctionException.class); + try { + row.getBoolean(0); + fail("Expected row.getBoolean(0) to fail on NULL"); + } catch (DuckDBFunctionException exception) { + assertTrue(exception.getMessage().contains("Failed to read BOOLEAN")); + assertNotNull(exception.getCause()); + assertTrue(exception.getCause() instanceof DuckDBFunctionException); + assertTrue(exception.getCause().getMessage().contains("Primitive value for BOOLEAN")); + } + assertTrue(row.getBoolean(0, true)); + assertThrows(() -> { row.getByte(1); }, DuckDBFunctionException.class); + assertEquals(row.getByte(1, (byte) 42), (byte) 42); + assertThrows(() -> { row.getUint8(2); }, DuckDBFunctionException.class); + assertEquals(row.getUint8(2, (short) 42), (short) 42); + assertThrows(() -> { row.getShort(3); }, DuckDBFunctionException.class); + assertEquals(row.getShort(3, (short) 42), (short) 42); + assertThrows(() -> { row.getUint16(4); }, DuckDBFunctionException.class); + assertEquals(row.getUint16(4, 42), 42); + assertThrows(() -> { row.getInt(5); }, DuckDBFunctionException.class); + assertEquals(row.getInt(5, 42), 42); + assertThrows(() -> { row.getUint32(6); }, DuckDBFunctionException.class); + assertEquals(row.getUint32(6, (long) 42), (long) 42); + assertThrows(() -> { row.getLong(7); }, DuckDBFunctionException.class); + assertEquals(row.getLong(7, (long) 42), (long) 42); + assertThrows(() -> { row.getFloat(8); }, DuckDBFunctionException.class); + assertEquals(row.getFloat(8, (float) 42.1), (float) 42.1); + assertThrows(() -> { row.getDouble(9); }, DuckDBFunctionException.class); + assertEquals(row.getDouble(9, 42.1), 42.1); + } catch (Exception e) { + throw new RuntimeException(e); + } + row.setString("ok"); + }); + }) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_primitive_nulls_handling(NULL::BOOLEAN, NULL::TINYINT, NULL::UTINYINT, NULL::SMALLINT," + + " NULL::USMALLINT, NULL::INTEGER, NULL::UINTEGER, NULL::BIGINT, NULL::FLOAT, NULL::DOUBLE)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ok"); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_context_row_stream_propagate_nulls_false() throws Exception { + assertUnaryScalarFunction( + "java_suffix_varchar_row_stream_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { + ctx.stream().forEachOrdered(row -> { + String value = row.getString(0); + row.setString(value == null ? "NULL_SEEN" : value + "_ok"); + }); + }, + "SELECT java_suffix_varchar_row_stream_nullable(v) FROM (VALUES ('duck'), (NULL), ('db')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "NULL_SEEN"); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "db_ok"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer_append() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_append") + .withParameter(DuckDBColumnType.INTEGER) + .withReturnType(DuckDBColumnType.INTEGER) + .withIntFunction(x -> x + 1) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_append(v) FROM (VALUES (41), (NULL), (-2)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), -1); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_bigint_append() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_bigint_append") + .withParameter(DuckDBColumnType.BIGINT) + .withReturnType(DuckDBColumnType.BIGINT) + .withLongFunction(x -> x + 1) + .register(conn); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_bigint_append(v) FROM (VALUES (41::BIGINT), (NULL), (-2::BIGINT)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -1L); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_exception_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_throws_exception") + .withParameter(intType) + .withReturnType(intType) + .withVectorizedFunction(ctx -> { throw new IllegalStateException("boom"); }) + .register(conn); + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_throws_exception(1)"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("IllegalStateException")); + assertTrue(message.contains("boom")); + } + } + + public static void test_register_scalar_function_boolean() throws Exception { + assertUnaryScalarFunction( + "java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setBoolean(!row.getBoolean(0))); }, + "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_tinyint() throws Exception { + assertUnaryScalarFunction( + "java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setByte((byte) (row.getByte(0) + 1))); }, + "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_smallint() throws Exception { + assertUnaryScalarFunction( + "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setShort((short) (row.getShort(0) + 2))); + }, + "SELECT java_add_smallint(v) FROM (VALUES (40::SMALLINT), (NULL), (-4::SMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) -2); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer() throws Exception { + assertUnaryScalarFunction( + "java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }, + "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { + assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { + row.setNull(); + row.setInt(row.getInt(0) + 1); + }); + }, + "SELECT java_revalidate_int(v) FROM (VALUES (41), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_bigint() throws Exception { + assertUnaryScalarFunction( + "java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setLong(row.getLong(0) + 3)); }, + "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_utinyint() throws Exception { + assertUnaryScalarFunction( + "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint8(row.getUint8(0) + 1)); }, + "SELECT java_add_utinyint(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 255); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_usmallint() throws Exception { + assertUnaryScalarFunction( + "java_add_usmallint", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint16(row.getUint16(0) + 2)); }, + "SELECT java_add_usmallint(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 65535); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_uinteger() throws Exception { + assertUnaryScalarFunction( + "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint32(row.getUint32(0) + 3)); }, + "SELECT java_add_uinteger(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 4294967295L); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_ubigint() throws Exception { + assertUnaryScalarFunction( + "java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, + ctx + -> { + BigInteger increment = BigInteger.ONE; + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint64(row.getUint64(0).add(increment))); + }, + "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_uhugeint() throws Exception { + assertUnaryScalarFunction("java_add_uhugeint", DuckDBColumnType.UHUGEINT, DuckDBColumnType.UHUGEINT, + ctx + -> { + BigInteger increment = BigInteger.ONE; + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setUHugeInt(row.getUHugeInt(0).add(increment))); + }, + "SELECT java_add_uhugeint(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("340282366920938463463374607431768211455")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_uhugeint() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_uhugeint_function") + .withParameter(DuckDBColumnType.UHUGEINT) + .withReturnType(DuckDBColumnType.UHUGEINT) + .withFunction((BigInteger value) -> null != value ? value.add(BigInteger.ONE) : null) + .register(conn); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_uhugeint_function(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("340282366920938463463374607431768211455")); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_float() throws Exception { + assertUnaryScalarFunction( + "java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setFloat(row.getFloat(0) + 1.25f)); }, + "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), 42.0f); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), -1.25f); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_double() throws Exception { + assertUnaryScalarFunction( + "java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }, + "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_decimal() throws Exception { + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { + assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, + ctx + -> { + BigDecimal increment = new BigDecimal("0.0000000001"); + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setBigDecimal(row.getBigDecimal(0).add(increment))); + }, + "SELECT java_add_decimal(v) FROM (VALUES " + + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " + + "(NULL), " + + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), + new BigDecimal("12345678901234567890.1234567891")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); + assertFalse(rs.next()); + }); + } + } + + public static void test_register_scalar_function_decimal_precision_overflow() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + DuckDBFunctions.scalarFunction() + .withName("java_decimal_precision_overflow") + .withParameter(decimalType) + .withReturnType(decimalType) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("12345678901.23")); + } + }) + .register(conn); + + String err = assertThrows(() -> { + stmt.execute("SELECT java_decimal_precision_overflow(CAST(1 AS DECIMAL(10,2)))"); + }, SQLException.class); + assertTrue(err.contains("DECIMAL(10,2)")); + } + } + + public static void test_register_scalar_function_decimal_scale_overflow() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + DuckDBFunctions.scalarFunction() + .withName("java_decimal_scale_overflow") + .withParameter(decimalType) + .withReturnType(decimalType) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("1.234")); + } + }) + .register(conn); + + String err = assertThrows(() -> { + stmt.execute("SELECT java_decimal_scale_overflow(CAST(1 AS DECIMAL(10,2)))"); + }, SQLException.class); + assertTrue(err.contains("DECIMAL(10,2)")); + } + } + + public static void test_register_scalar_function_date() throws Exception { + assertUnaryScalarFunction( + "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDate(row.getLocalDate(0).plusDays(2))); + }, + "SELECT java_add_date(v) FROM (VALUES (DATE '2024-07-20'), (NULL), (DATE '1969-12-31')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 22)); + assertEquals(rs.getDate(1), Date.valueOf(LocalDate.of(2024, 7, 22))); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(1970, 1, 2)); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_date_from_java_util_date() throws Exception { + assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { + LocalDate value = row.getLocalDate(0).plusDays(1); + row.setDate(java.util.Date.from(value.atStartOfDay(UTC).toInstant())); + }); + }, + "SELECT java_date_from_util_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 22)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp() throws Exception { + assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusMinutes(30))); + }, + "SELECT java_add_timestamp(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + Timestamp ts1 = rs.getTimestamp(1); + assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_s() throws Exception { + assertUnaryScalarFunction( + "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusSeconds(2))); + }, + "SELECT java_add_timestamp_s(v) FROM (VALUES (TIMESTAMP_S '2024-07-21 12:34:56'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-21 12:34:58")); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { + assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, + DuckDBColumnType.TIMESTAMP_S, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, + "SELECT java_copy_timestamp_s_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.999')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1969-12-31 23:59:59")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ms() throws Exception { + assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(7_000_000))); + }, + "SELECT java_add_timestamp_ms(v) FROM (VALUES " + + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { + assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, + DuckDBColumnType.TIMESTAMP_MS, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, + "SELECT java_copy_timestamp_ms_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.9995')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999_000_000)); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ns() throws Exception { + assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(789))); + }, + "SELECT java_add_timestamp_ns(v) FROM (VALUES " + + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamptz() throws Exception { + assertUnaryScalarFunction( + "java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setOffsetDateTime(row.getOffsetDateTime(0).plusMinutes(5))); + }, + "SELECT java_add_timestamptz(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamptz_set_timestamp() throws Exception { + assertUnaryScalarFunction( + "java_copy_timestamptz_with_timestamp", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getTimestamp(0))); }, + "SELECT java_copy_timestamptz_with_timestamp(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00')) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 34, 56, 123456000, ZoneOffset.UTC))); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_date", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { + long oneSecondMillis = 1000L; + ctx.propagateNulls(true).stream().forEachOrdered( + row -> { row.setTimestamp(new java.util.Date(row.getTimestamp(0).getTime() + oneSecondMillis)); }); + }, + "SELECT epoch_ms(java_timestamp_from_util_date(v)) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + Timestamp input = Timestamp.valueOf("2024-07-21 12:34:56.123456"); + assertEquals(rs.getLong(1), input.getTime() + 1000L); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { + java.util.Date value = Timestamp.valueOf(row.getLocalDateTime(0).plusNanos(789000)); + row.setTimestamp(value); + }); + }, + "SELECT java_timestamp_from_util_ts(v) FROM (VALUES (TIMESTAMP '2024-07-21 12:34:56.123456')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-21 12:34:56.124245")); + assertFalse(rs.wasNull()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_sql_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.stream().forEachOrdered(row -> { + java.util.Date value = Date.valueOf(row.getLocalDate(0)); + row.setTimestamp(value); + }); + }, + "SELECT epoch_ms(java_timestamp_from_util_sql_date(v)) FROM (VALUES (DATE '2024-07-21')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), Date.valueOf("2024-07-21").getTime()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_sql_time", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.stream().forEachOrdered(row -> { + java.util.Date value = Time.valueOf("12:34:56"); + row.setTimestamp(value); + }); + }, + "SELECT epoch_ms(java_timestamp_from_util_sql_time(v)) FROM (VALUES (TIMESTAMP '2024-07-21 00:00:00')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), Time.valueOf("12:34:56").getTime()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDate(0).plusDays(1))); + }, + "SELECT java_timestamp_from_local_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-22 00:00:00")); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 22, 0, 0)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar() throws Exception { + assertUnaryScalarFunction( + "java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0) + "_java")); }, + "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_java"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { + assertUnaryScalarFunction( + "java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0))); }, + "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { + assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { + row.setNull(); + row.setString(row.getString(0) + "_ok"); + }); + }, + "SELECT java_revalidate_varchar(v) FROM (VALUES ('duck'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, DuckDBScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); + DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { + assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); + } + } + + private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, DuckDBScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withVectorizedFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, Class parameterType, Class returnType, + Function function, String query, ResultSetVerifier verifier) + throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertBinaryJavaFunction(String functionName, Class leftType, Class rightType, + Class returnType, BiFunction function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(leftType) + .withParameter(rightType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullaryJavaFunction(String functionName, Class returnType, Supplier function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullRow(ResultSet rs) throws Exception { + assertEquals(rs.getObject(1), null); + assertTrue(rs.wasNull()); + } + + private static final java.time.ZoneOffset UTC = java.time.ZoneOffset.UTC; +}