From 4707012e2b49060a5065b71e528007dee65c6757 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Tue, 31 Mar 2026 15:23:50 -0300 Subject: [PATCH 1/9] add java vectorized scalar function support --- CMakeLists.txt | 1 + README.md | 2 + UDF.MD | 63 ++ duckdb_java.def | 9 + duckdb_java.exp | 9 + duckdb_java.map | 9 + src/jni/bindings_logical_type.cpp | 37 + src/jni/bindings_scalar_function.cpp | 90 ++ src/jni/bindings_vector.cpp | 45 + src/jni/duckdb_java.cpp | 291 ++++++ src/jni/refs.cpp | 12 + src/jni/refs.hpp | 6 + src/main/java/org/duckdb/DuckDBBindings.java | 22 + .../java/org/duckdb/DuckDBConnection.java | 63 ++ .../org/duckdb/DuckDBDataChunkReader.java | 44 + .../java/org/duckdb/DuckDBReadableVector.java | 262 +++++ .../java/org/duckdb/DuckDBVectorTypeInfo.java | 87 ++ .../DuckDBVectorizedScalarFunction.java | 16 + .../java/org/duckdb/DuckDBWritableVector.java | 406 ++++++++ src/test/java/org/duckdb/TestBindings.java | 102 ++ src/test/java/org/duckdb/TestDuckDBJDBC.java | 13 +- .../java/org/duckdb/TestScalarFunctions.java | 909 ++++++++++++++++++ 22 files changed, 2492 insertions(+), 6 deletions(-) create mode 100644 UDF.MD create mode 100644 src/jni/bindings_scalar_function.cpp create mode 100644 src/main/java/org/duckdb/DuckDBDataChunkReader.java create mode 100644 src/main/java/org/duckdb/DuckDBReadableVector.java create mode 100644 src/main/java/org/duckdb/DuckDBVectorTypeInfo.java create mode 100644 src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java create mode 100644 src/main/java/org/duckdb/DuckDBWritableVector.java create mode 100644 src/test/java/org/duckdb/TestScalarFunctions.java diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d286cc75..e5bbe070a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -591,6 +591,7 @@ add_library(duckdb_java SHARED src/jni/bindings_common.cpp src/jni/bindings_data_chunk.cpp src/jni/bindings_logical_type.cpp + src/jni/bindings_scalar_function.cpp src/jni/bindings_validity.cpp 
src/jni/bindings_vector.cpp src/jni/config.cpp diff --git a/README.md b/README.md index 4042a68db..3f1b60c52 100644 --- a/README.md +++ b/README.md @@ -20,3 +20,5 @@ This optionally takes an argument to only run a single test, for example: ``` java -cp "build/release/duckdb_jdbc_tests.jar:build/release/duckdb_jdbc.jar" org/duckdb/TestDuckDBJDBC test_valid_but_local_config_throws_exception ``` + +Scalar function usage examples: [UDF.MD](UDF.MD) diff --git a/UDF.MD b/UDF.MD new file mode 100644 index 000000000..7b81cab9e --- /dev/null +++ b/UDF.MD @@ -0,0 +1,63 @@ +# Java Scalar Functions (UDF) + +Use `DuckDBConnection.registerScalarFunction` to register a vectorized scalar function in Java. + +```java +void registerScalarFunction( + String name, + String[] parameterTypes, + String returnType, + DuckDBVectorizedScalarFunction function +) throws SQLException +``` + +Notes: +- `parameterTypes` and `returnType` are SQL type strings (for example: `INTEGER`, `VARCHAR`, `TIMESTAMP`). +- The callback is vectorized: process `rowCount` rows from the input chunk and write one value per row into `out`. +- `DuckDBDataChunkReader` / `DuckDBReadableVector` / `DuckDBWritableVector` are valid only during the callback. 
+ +## Simple example + +```java +try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { + conn.registerScalarFunction("java_add_one", new String[] {"INTEGER"}, "INTEGER", (input, rowCount, out) -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setInt(i, in.getInt(i) + 1); + } + } + }); +} +``` + +## More complete example + +Build a label from `TIMESTAMP` + `VARCHAR` + `DOUBLE`, preserving `NULL` behavior: + +```java +try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { + conn.registerScalarFunction( + "java_event_label", + new String[] {"TIMESTAMP", "VARCHAR", "DOUBLE"}, + "VARCHAR", + (input, rowCount, out) -> { + DuckDBReadableVector ts = input.vector(0); + DuckDBReadableVector tag = input.vector(1); + DuckDBReadableVector score = input.vector(2); + + for (int i = 0; i < rowCount; i++) { + if (ts.isNull(i) || tag.isNull(i) || score.isNull(i)) { + out.setNull(i); + continue; + } + String value = + ts.getLocalDateTime(i) + " | " + tag.getString(i).trim().toUpperCase() + " | " + score.getDouble(i); + out.setString(i, value); + } + } + ); +} +``` diff --git a/duckdb_java.def b/duckdb_java.def index 68ff3031b..8edfc8220 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -52,7 +52,15 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size +Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function 
+Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale @@ -98,6 +106,7 @@ Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk +Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback duckdb_adbc_init duckdb_add_aggregate_function_to_set diff --git a/duckdb_java.exp b/duckdb_java.exp index 6b6cb687d..b6811a443 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -49,7 +49,16 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup _Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale diff --git a/duckdb_java.map b/duckdb_java.map index 
7ed2d7233..635070538 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -51,7 +51,15 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup; Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale; @@ -97,6 +105,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type; Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk; + Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback; duckdb_adbc_init; duckdb_add_aggregate_function_to_set; diff --git a/src/jni/bindings_logical_type.cpp b/src/jni/bindings_logical_type.cpp index 63bc64cc0..ac4f0c568 100644 --- a/src/jni/bindings_logical_type.cpp +++ b/src/jni/bindings_logical_type.cpp @@ -1,4 +1,6 @@ #include "bindings.hpp" +#include "duckdb/common/types.hpp" +#include "holders.hpp" #include "refs.hpp" #include "util.hpp" @@ -36,6 +38,41 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical return make_ptr_buf(env, lt); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_jdbc_parse_logical_type + * Signature: (Ljava/nio/ByteBuffer;[B)Ljava/nio/ByteBuffer; + */ +JNIEXPORT 
jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type(JNIEnv *env, jclass, + jobject connection, + jbyteArray type_name) { + + if (type_name == nullptr) { + env->ThrowNew(J_SQLException, "Invalid logical type name"); + return nullptr; + } + + try { + auto sql_type_name = jbyteArray_to_string(env, type_name); + duckdb::LogicalType logical_type; + if (connection) { + auto conn = get_connection(env, connection); + if (env->ExceptionCheck()) { + return nullptr; + } + conn->context->RunFunctionInTransaction( + [&]() { logical_type = duckdb::TransformStringToLogicalType(sql_type_name, *conn->context); }); + } else { + logical_type = duckdb::TransformStringToLogicalType(sql_type_name); + } + return make_ptr_buf(env, + reinterpret_cast(new duckdb::LogicalType(std::move(logical_type)))); + } catch (const std::exception &e) { + env->ThrowNew(J_SQLException, e.what()); + return nullptr; + } +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_get_type_id diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp new file mode 100644 index 000000000..c59e1387a --- /dev/null +++ b/src/jni/bindings_scalar_function.cpp @@ -0,0 +1,90 @@ +#include "bindings.hpp" +#include "holders.hpp" +#include "refs.hpp" +#include "util.hpp" + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + duckdb_scalar_function scalar_function = + reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + + return scalar_function; +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function(JNIEnv *env, jclass) { + return make_ptr_buf(env, duckdb_create_scalar_function()); +} + +JNIEXPORT 
void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function(JNIEnv *env, jclass, + jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_destroy_scalar_function(&function); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name(JNIEnv *env, jclass, + jobject scalar_function, + jbyteArray name) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + if (name == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function name"); + return; + } + auto function_name = jbyteArray_to_string(env, name); + duckdb_scalar_function_set_name(function, function_name.c_str()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter(JNIEnv *env, jclass, + jobject scalar_function, + jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_add_parameter(function, type); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type( + JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_return_type(function, type); +} + +JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function(JNIEnv *env, jclass, + jobject connection, + jobject scalar_function) { + auto conn = conn_ref_buf_to_conn(env, connection); + if 
(env->ExceptionCheck()) { + return static_cast(DuckDBError); + } + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return static_cast(DuckDBError); + } + return static_cast(duckdb_register_scalar_function(conn, function)); +} diff --git a/src/jni/bindings_vector.cpp b/src/jni/bindings_vector.cpp index 56876d68e..9bda6e778 100644 --- a/src/jni/bindings_vector.cpp +++ b/src/jni/bindings_vector.cpp @@ -21,6 +21,51 @@ static duckdb_vector vector_buf_to_vector(JNIEnv *env, jobject vector_buf) { return vector; } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_jdbc_varchar_string_bytes + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes( + JNIEnv *env, jclass, jobject vector_data, jobject validity, jlong row_count, jlong row) { + + if (vector_data == nullptr) { + env->ThrowNew(J_SQLException, "Invalid vector data buffer"); + return nullptr; + } + auto data = reinterpret_cast(env->GetDirectBufferAddress(vector_data)); + if (data == nullptr) { + env->ThrowNew(J_SQLException, "Invalid vector data"); + return nullptr; + } + idx_t row_count_idx = jlong_to_idx(env, row_count); + if (env->ExceptionCheck()) { + return nullptr; + } + idx_t row_idx = jlong_to_idx(env, row); + if (env->ExceptionCheck()) { + return nullptr; + } + if (row_idx >= row_count_idx) { + env->ThrowNew(J_SQLException, "Row index out of bounds"); + return nullptr; + } + if (validity != nullptr) { + auto mask = reinterpret_cast(env->GetDirectBufferAddress(validity)); + if (mask == nullptr) { + env->ThrowNew(J_SQLException, "Invalid validity buffer"); + return nullptr; + } + if ((mask[row_idx / 64] & (1ULL << (row_idx % 64))) == 0) { + return nullptr; + } + } + auto &string_value = data[row_idx]; + auto string_len = duckdb_string_t_length(string_value); + auto string_ptr = duckdb_string_t_data(&string_value); + return 
make_jbyteArray(env, string_ptr, string_len); +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_create_vector diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 436dda5c4..408e6c30c 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -24,11 +24,13 @@ extern "C" { #include #include +#include using namespace duckdb; using namespace std; static jint JNI_VERSION = JNI_VERSION_1_6; +static JavaVM *JVM_REF = nullptr; void ThrowJNI(JNIEnv *env, const char *message) { D_ASSERT(J_SQLException); @@ -40,6 +42,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { return JNI_ERR; } + JVM_REF = vm; try { create_refs(env); @@ -62,6 +65,281 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { return; } delete_global_refs(env); + JVM_REF = nullptr; +} + +struct JNIEnvGuard { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; + + explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { + if (!vm) { + throw InvalidInputException("JVM is not available"); + } + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + if (get_env_status == JNI_OK) { + return; + } + if (get_env_status != JNI_EDETACHED) { + throw InvalidInputException("Failed to get JNI environment"); + } + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw InvalidInputException("Failed to attach current thread to JVM"); + } + detach_when_done = true; + } + + ~JNIEnvGuard() { + if (detach_when_done && vm) { + vm->DetachCurrentThread(); + } + } +}; + +struct JavaScalarFunctionState { + JavaVM *vm; + jobject callback; + jmethodID apply_method; + + JavaScalarFunctionState(JavaVM *vm_p, jobject callback_p, jmethodID apply_method_p) + : vm(vm_p), callback(callback_p), apply_method(apply_method_p) { + } + + ~JavaScalarFunctionState() { + if (!vm || !callback) { + 
return; + } + try { + JNIEnvGuard env_guard(vm); + env_guard.env->DeleteGlobalRef(callback); + } catch (...) { + // noop in destructor + } + } +}; + +struct JavaScalarFunctionLocalState { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; +}; + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + auto scalar_function = reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + return scalar_function; +} + +jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_count); + +static string consume_java_exception_message(JNIEnv *env) { + auto throwable = env->ExceptionOccurred(); + if (!throwable) { + return "Java exception"; + } + env->ExceptionClear(); + + string message = "Java exception"; + auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); + if (!env->ExceptionCheck() && msg) { + message = jstring_to_string(env, msg); + } + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + + env->DeleteLocalRef(throwable); + if (msg) { + env->DeleteLocalRef(msg); + } + + return message; +} + +static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { + if (!vm) { + throw InvalidInputException("JVM is not available"); + } + + detach_when_done = false; + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + if (get_env_status == JNI_OK) { + return; + } + if (get_env_status != JNI_EDETACHED) { + throw InvalidInputException("Failed to get JNI environment"); + } + + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw InvalidInputException("Failed to attach current thread to JVM"); + } + 
detach_when_done = true; +} + +static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, DataChunk &input, + Vector &output) { + auto row_count = input.size(); + jobject input_chunk_buf = make_ptr_buf(env, &input); + jobject output_vector_buf = make_ptr_buf(env, &output); + auto input_reader = env->NewObject(J_DuckDataChunkReader, J_DuckDataChunkReader_init, input_chunk_buf, + static_cast(row_count)); + if (env->ExceptionCheck()) { + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + throw InvalidInputException("Could not create DuckDBDataChunkReader: %s", consume_java_exception_message(env)); + } + + auto output_writer = env->NewObject(J_DuckWritableVector, J_DuckWritableVector_init, output_vector_buf, + static_cast(row_count)); + if (env->ExceptionCheck()) { + env->DeleteLocalRef(input_reader); + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + throw InvalidInputException("Could not create DuckDBWritableVector: %s", consume_java_exception_message(env)); + } + + env->CallVoidMethod(state.callback, state.apply_method, input_reader, static_cast(row_count), output_writer); + + env->DeleteLocalRef(output_writer); + env->DeleteLocalRef(input_reader); + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + + if (env->ExceptionCheck()) { + throw InvalidInputException("Java scalar function threw exception: %s", consume_java_exception_message(env)); + } +} + +static void destroy_java_scalar_function_state(void *extra_info); +static void init_java_scalar_function_capi(duckdb_init_info info); +static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); + +static jmethodID get_scalar_callback_method(JNIEnv 
*env, jobject function_j, const char *signature, + const char *error_message) { + auto callback_class = env->GetObjectClass(function_j); + auto apply_method = env->GetMethodID(callback_class, "apply", signature); + env->DeleteLocalRef(callback_class); + if (!apply_method || env->ExceptionCheck()) { + consume_java_exception_message(env); + throw InvalidInputException("%s", error_message); + } + return apply_method; +} + +static void install_java_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, + jobject function_j, const char *signature, + const char *error_message) { + auto connection = get_connection(env, conn_ref_buf); + if (!connection) { + throw InvalidInputException("Invalid connection"); + } + auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); + if (env->ExceptionCheck()) { + return; + } + if (!function_j) { + throw InvalidInputException("Invalid scalar function callback"); + } + + auto callback_ref = env->NewGlobalRef(function_j); + if (!callback_ref) { + throw InvalidInputException("Could not create global reference for scalar function callback"); + } + + try { + auto apply_method = get_scalar_callback_method(env, function_j, signature, error_message); + auto state = new JavaScalarFunctionState(JVM_REF, callback_ref, apply_method); + duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); + duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); + duckdb_scalar_function_set_init(scalar_function, init_java_scalar_function_capi); + } catch (...) 
{ + env->DeleteGlobalRef(callback_ref); + throw; + } +} + +static void destroy_java_scalar_function_state(void *extra_info) { + if (!extra_info) { + return; + } + delete reinterpret_cast(extra_info); +} + +static void destroy_java_scalar_function_local_state(void *state_ptr) { + if (!state_ptr) { + return; + } + + auto state = reinterpret_cast(state_ptr); + if (state->detach_when_done && state->vm) { + state->vm->DetachCurrentThread(); + } + delete state; +} + +static void init_java_scalar_function_capi(duckdb_init_info info) { + JavaScalarFunctionLocalState *local_state = nullptr; + try { + auto state = reinterpret_cast(duckdb_scalar_function_init_get_extra_info(info)); + if (!state) { + duckdb_scalar_function_init_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + local_state = new JavaScalarFunctionLocalState(); + local_state->vm = state->vm; + local_state->env = nullptr; + local_state->detach_when_done = false; + get_or_attach_jni_env(local_state->vm, local_state->env, local_state->detach_when_done); + duckdb_scalar_function_init_set_state(info, local_state, destroy_java_scalar_function_local_state); + local_state = nullptr; + } catch (const std::exception &e) { + if (local_state) { + destroy_java_scalar_function_local_state(local_state); + } + duckdb_scalar_function_init_set_error(info, e.what()); + } +} + +static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, + duckdb_vector output) { + auto state = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); + auto local_state = reinterpret_cast(duckdb_scalar_function_get_state(info)); + if (!state || !local_state || !local_state->env || !input || !output) { + duckdb_scalar_function_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + try { + auto &input_chunk = *reinterpret_cast(input); + auto &output_vector = *reinterpret_cast(output); + execute_java_vectorized_scalar_function(local_state->env, *state, 
input_chunk, output_vector); + } catch (const std::exception &e) { + duckdb_scalar_function_set_error(info, e.what()); + } } //! The database instance cache, used so that multiple connections to the same file point to the same database object @@ -899,6 +1177,19 @@ void _duckdb_jdbc_arrow_register(JNIEnv *env, jclass, jobject conn_ref_buf, jlon conn->TableFunction("arrow_scan_dumb", parameters)->CreateView(name, true, true); } +extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback( + JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { + try { + install_java_scalar_function_callback(env, conn_ref_buf, scalar_function_buf, function_j, + "(Lorg/duckdb/DuckDBDataChunkReader;ILorg/duckdb/DuckDBWritableVector;)V", + "Could not find apply(DuckDBDataChunkReader, int, DuckDBWritableVector) " + "on scalar function callback"); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + void _duckdb_jdbc_create_extension_type(JNIEnv *env, jclass, jobject conn_buf) { auto connection = get_connection(env, conn_buf); diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index 1dc6cbafe..d76cfbd80 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -61,6 +61,12 @@ jmethodID J_DuckVector_retainConstlenData; jfieldID J_DuckVector_constlen; jfieldID J_DuckVector_varlen; +jclass J_DuckDataChunkReader; +jmethodID J_DuckDataChunkReader_init; + +jclass J_DuckWritableVector; +jmethodID J_DuckWritableVector_init; + jclass J_DuckArray; jmethodID J_DuckArray_init; @@ -287,6 +293,12 @@ void create_refs(JNIEnv *env) { J_DuckVector_constlen = get_field_id(env, J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;"); J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); + J_DuckDataChunkReader = make_class_ref(env, "org/duckdb/DuckDBDataChunkReader"); + J_DuckDataChunkReader_init = get_method_id(env, 
J_DuckDataChunkReader, "", "(Ljava/nio/ByteBuffer;I)V"); + + J_DuckWritableVector = make_class_ref(env, "org/duckdb/DuckDBWritableVector"); + J_DuckWritableVector_init = get_method_id(env, J_DuckWritableVector, "", "(Ljava/nio/ByteBuffer;I)V"); + J_ByteBuffer = make_class_ref(env, "java/nio/ByteBuffer"); J_ByteBuffer_order = get_method_id(env, J_ByteBuffer, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;"); J_ByteOrder = make_class_ref(env, "java/nio/ByteOrder"); diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index cda859d33..94236dcf8 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -58,6 +58,12 @@ extern jmethodID J_DuckVector_retainConstlenData; extern jfieldID J_DuckVector_constlen; extern jfieldID J_DuckVector_varlen; +extern jclass J_DuckDataChunkReader; +extern jmethodID J_DuckDataChunkReader_init; + +extern jclass J_DuckWritableVector; +extern jmethodID J_DuckWritableVector_init; + extern jclass J_DuckArray; extern jmethodID J_DuckArray_init; diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 4ee45c04d..3e0d00b0e 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -17,10 +17,32 @@ public class DuckDBBindings { static native long duckdb_vector_size(); + // scalar function + + static native ByteBuffer duckdb_create_scalar_function(); + + static native void duckdb_destroy_scalar_function(ByteBuffer scalarFunction); + + static native void duckdb_scalar_function_set_name(ByteBuffer scalarFunction, byte[] name); + + static native void duckdb_scalar_function_add_parameter(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native void duckdb_scalar_function_set_return_type(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalarFunction); + + static native void duckdb_jdbc_scalar_function_set_callback(ByteBuffer connection, 
ByteBuffer scalarFunction, + DuckDBVectorizedScalarFunction function); + + static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer vectorData, ByteBuffer validity, long rowCount, + long row); + // logical type static native ByteBuffer duckdb_create_logical_type(int duckdb_type); + static native ByteBuffer duckdb_jdbc_parse_logical_type(ByteBuffer connection, byte[] type_name); + static native int duckdb_get_type_id(ByteBuffer logical_type); static native int duckdb_decimal_width(ByteBuffer logical_type); diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index d51c0c00e..5acbeaa4a 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -500,6 +500,69 @@ public void registerArrowStream(String name, Object arrow_array_stream) { } } + public void registerScalarFunction(String name, String[] parameterTypes, String returnType, + DuckDBVectorizedScalarFunction function) throws SQLException { + checkOpen(); + connRefLock.lock(); + ByteBuffer scalarFunction = null; + ByteBuffer returnLogicalType = null; + ByteBuffer[] parameterLogicalTypes = null; + try { + checkOpen(); + if (name == null || name.trim().isEmpty()) { + throw new SQLException("Function name cannot be null or empty"); + } + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + for (int i = 0; i < parameterTypes.length; i++) { + String parameterType = parameterTypes[i]; + if (parameterType == null || parameterType.trim().isEmpty()) { + throw new SQLException("Parameter type at index " + i + " cannot be null or empty"); + } + } + if (returnType == null || returnType.trim().isEmpty()) { + throw new SQLException("Return type cannot be null or empty"); + } + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + + scalarFunction = DuckDBBindings.duckdb_create_scalar_function(); + 
DuckDBBindings.duckdb_scalar_function_set_name(scalarFunction, name.getBytes(UTF_8)); + + parameterLogicalTypes = new ByteBuffer[parameterTypes.length]; + for (int i = 0; i < parameterTypes.length; i++) { + parameterLogicalTypes[i] = + DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, parameterTypes[i].getBytes(UTF_8)); + DuckDBBindings.duckdb_scalar_function_add_parameter(scalarFunction, parameterLogicalTypes[i]); + } + + returnLogicalType = DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, returnType.getBytes(UTF_8)); + DuckDBBindings.duckdb_scalar_function_set_return_type(scalarFunction, returnLogicalType); + DuckDBBindings.duckdb_jdbc_scalar_function_set_callback(connRef, scalarFunction, function); + + if (DuckDBBindings.duckdb_register_scalar_function(connRef, scalarFunction) != 0) { + throw new SQLException("Failed to register scalar function '" + name + "'"); + } + } finally { + if (returnLogicalType != null) { + DuckDBBindings.duckdb_destroy_logical_type(returnLogicalType); + } + if (parameterLogicalTypes != null) { + for (ByteBuffer parameterLogicalType : parameterLogicalTypes) { + if (parameterLogicalType != null) { + DuckDBBindings.duckdb_destroy_logical_type(parameterLogicalType); + } + } + } + if (scalarFunction != null) { + DuckDBBindings.duckdb_destroy_scalar_function(scalarFunction); + } + connRefLock.unlock(); + } + } + public String getProfilingInformation(ProfilerPrintFormat format) throws SQLException { checkOpen(); connRefLock.lock(); diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java new file mode 100644 index 000000000..c997f8842 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -0,0 +1,44 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import java.sql.SQLException; + +public final class DuckDBDataChunkReader { + private final ByteBuffer chunkRef; + private final int rowCount; + private 
final int columnCount; + private final DuckDBReadableVector[] vectors; + + DuckDBDataChunkReader(ByteBuffer chunkRef, int rowCount) throws SQLException { + if (chunkRef == null) { + throw new SQLException("Invalid data chunk reference"); + } + this.chunkRef = chunkRef; + this.rowCount = rowCount; + this.columnCount = (int) duckdb_data_chunk_get_column_count(chunkRef); + this.vectors = new DuckDBReadableVector[columnCount]; + } + + public int rowCount() { + return rowCount; + } + + public int columnCount() { + return columnCount; + } + + public DuckDBReadableVector vector(int columnIndex) throws SQLException { + if (columnIndex < 0 || columnIndex >= columnCount) { + throw new IndexOutOfBoundsException("Column index out of bounds: " + columnIndex); + } + DuckDBReadableVector vector = vectors[columnIndex]; + if (vector == null) { + ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); + vector = new DuckDBReadableVector(vectorRef, rowCount); + vectors[columnIndex] = vector; + } + return vector; + } +} diff --git a/src/main/java/org/duckdb/DuckDBReadableVector.java b/src/main/java/org/duckdb/DuckDBReadableVector.java new file mode 100644 index 000000000..c2db43a7b --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBReadableVector.java @@ -0,0 +1,262 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; + +public final class DuckDBReadableVector { + private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); + private static final ByteOrder NATIVE_ORDER = 
package org.duckdb;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.duckdb.DuckDBBindings.*;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.sql.Date;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;

/**
 * Read-only view over one DuckDB input vector inside a scalar-function callback.
 *
 * <p>Typed getters first check the vector's logical type, then read the value
 * directly from the native data buffer. Instances are valid only for the duration
 * of the callback and must not be retained.
 */
public final class DuckDBReadableVector {
    // 2^64 as a BigDecimal; used to recombine the two 64-bit halves of a HUGEINT.
    private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616");
    // Vector buffers are written by native code, so all multi-byte reads use the platform byte order.
    private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder();

    private final ByteBuffer vectorRef;  // native duckdb_vector handle
    private final int rowCount;
    private final DuckDBVectorTypeInfo typeInfo;
    private final ByteBuffer data;       // fixed-width data area, rowCount * widthBytes bytes
    private final ByteBuffer validity;   // validity bitmask; null means "all rows valid"

    DuckDBReadableVector(ByteBuffer vectorRef, int rowCount) throws SQLException {
        if (vectorRef == null) {
            throw new SQLException("Invalid vector reference");
        }
        this.vectorRef = vectorRef;
        this.rowCount = rowCount;
        this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef);
        this.data = duckdb_vector_get_data(vectorRef, (long) rowCount * typeInfo.widthBytes);
        this.validity = duckdb_vector_get_validity(vectorRef, rowCount);
    }

    /** JDBC-facing logical type of this vector. */
    public DuckDBColumnType getType() {
        return typeInfo.columnType;
    }

    /** Number of rows in this vector. */
    public int rowCount() {
        return rowCount;
    }

    /**
     * Returns true when the given row is NULL. The validity mask is an array of
     * 64-bit entries; a cleared bit marks a NULL row. A null mask means all valid.
     */
    public boolean isNull(int row) {
        checkRowIndex(row);
        if (validity == null) {
            return false;
        }
        int entryPos = (row / 64) * Long.BYTES;
        long mask = validity.order(NATIVE_ORDER).getLong(entryPos);
        return (mask & (1L << (row % 64))) == 0;
    }

    /** Reads a BOOLEAN row (stored as one byte, non-zero = true). */
    public boolean getBoolean(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.BOOLEAN);
        return data.get(row) != 0;
    }

    /** Reads a TINYINT row. */
    public byte getByte(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.TINYINT);
        return data.get(row);
    }

    /** Reads a SMALLINT row. */
    public short getShort(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.SMALLINT);
        return data.order(NATIVE_ORDER).getShort(row * Short.BYTES);
    }

    /** Reads a UTINYINT row, widened to short to preserve the unsigned range. */
    public short getUint8(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UTINYINT);
        return (short) Byte.toUnsignedInt(data.get(row));
    }

    /** Reads a USMALLINT row, widened to int to preserve the unsigned range. */
    public int getUint16(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.USMALLINT);
        return Short.toUnsignedInt(data.order(NATIVE_ORDER).getShort(row * Short.BYTES));
    }

    /** Reads an INTEGER row. */
    public int getInt(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.INTEGER);
        return data.order(NATIVE_ORDER).getInt(row * Integer.BYTES);
    }

    /** Reads a UINTEGER row, widened to long to preserve the unsigned range. */
    public long getUint32(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UINTEGER);
        return Integer.toUnsignedLong(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES));
    }

    /** Reads a BIGINT row. */
    public long getLong(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.BIGINT);
        return data.order(NATIVE_ORDER).getLong(row * Long.BYTES);
    }

    /** Reads a UBIGINT row as a non-negative BigInteger. */
    public BigInteger getUint64(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UBIGINT);
        long value = data.order(NATIVE_ORDER).getLong(row * Long.BYTES);
        return unsignedLongToBigInteger(value);
    }

    /** Reads a FLOAT row. */
    public float getFloat(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.FLOAT);
        return data.order(NATIVE_ORDER).getFloat(row * Float.BYTES);
    }

    /** Reads a DOUBLE row. */
    public double getDouble(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DOUBLE);
        return data.order(NATIVE_ORDER).getDouble(row * Double.BYTES);
    }

    /** Reads a DATE row (stored as days since the Unix epoch). */
    public LocalDate getLocalDate(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DATE);
        return LocalDate.ofEpochDay(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES));
    }

    /** Reads a DATE row as java.sql.Date. */
    public Date getDate(int row) throws SQLException {
        return Date.valueOf(getLocalDate(row));
    }

    /**
     * Reads any TIMESTAMP* row, converting from the variant's native epoch unit
     * (seconds / millis / micros / nanos; TZ values are micros interpreted with
     * time zone).
     */
    public LocalDateTime getLocalDateTime(int row) throws SQLException {
        checkRowIndex(row);
        requireTimestampType();
        long epochValue = data.order(NATIVE_ORDER).getLong(row * Long.BYTES);
        switch (typeInfo.capiType) {
        case DUCKDB_TYPE_TIMESTAMP_S:
            return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null);
        case DUCKDB_TYPE_TIMESTAMP_MS:
            return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null);
        case DUCKDB_TYPE_TIMESTAMP:
            return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null);
        case DUCKDB_TYPE_TIMESTAMP_NS:
            return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null);
        case DUCKDB_TYPE_TIMESTAMP_TZ:
            return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null);
        default:
            throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType);
        }
    }

    /** Reads any TIMESTAMP* row as java.sql.Timestamp. */
    public Timestamp getTimestamp(int row) throws SQLException {
        return Timestamp.valueOf(getLocalDateTime(row));
    }

    /**
     * Reads a TIMESTAMP WITH TIME ZONE row (micros since epoch), rendered in the
     * JVM's default time zone.
     */
    public OffsetDateTime getOffsetDateTime(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE);
        long micros = data.order(NATIVE_ORDER).getLong(row * Long.BYTES);
        Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS);
        return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime();
    }

    /**
     * Reads a DECIMAL row. The unscaled value is stored in one of four physical
     * integer widths; the HUGEINT case recombines the lower (unsigned) and upper
     * (signed) 64-bit halves.
     */
    public BigDecimal getBigDecimal(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DECIMAL);
        switch (typeInfo.storageType) {
        case DUCKDB_TYPE_SMALLINT:
            return BigDecimal.valueOf(data.order(NATIVE_ORDER).getShort(row * Short.BYTES), typeInfo.decimalMeta.scale);
        case DUCKDB_TYPE_INTEGER:
            return BigDecimal.valueOf(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES), typeInfo.decimalMeta.scale);
        case DUCKDB_TYPE_BIGINT:
            return BigDecimal.valueOf(data.order(NATIVE_ORDER).getLong(row * Long.BYTES), typeInfo.decimalMeta.scale);
        case DUCKDB_TYPE_HUGEINT: {
            // hugeint layout: lower 64 bits first, then upper 64 bits.
            ByteBuffer slice = data.duplicate().order(NATIVE_ORDER);
            slice.position(row * typeInfo.widthBytes);
            long lower = slice.getLong();
            long upper = slice.getLong();
            return new BigDecimal(upper)
                .multiply(ULONG_MULTIPLIER)
                .add(new BigDecimal(Long.toUnsignedString(lower)))
                .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1);
        }
        default:
            throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType);
        }
    }

    /** Reads a VARCHAR row as a UTF-8 string; returns null for NULL rows. */
    public String getString(int row) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.VARCHAR);
        if (isNull(row)) {
            return null;
        }
        // String payloads are resolved in native code (inline vs. pointer representation).
        byte[] bytes = duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, row);
        if (bytes == null) {
            return null;
        }
        return new String(bytes, UTF_8);
    }

    // Package-private access to the native handle for sibling classes.
    ByteBuffer vectorRef() {
        return vectorRef;
    }

    private void requireType(DuckDBColumnType expected) throws SQLException {
        if (typeInfo.columnType != expected) {
            throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType);
        }
    }

    // Accepts every TIMESTAMP variant, including the time-zone-aware one.
    private void requireTimestampType() throws SQLException {
        switch (typeInfo.columnType) {
        case TIMESTAMP:
        case TIMESTAMP_S:
        case TIMESTAMP_MS:
        case TIMESTAMP_NS:
        case TIMESTAMP_WITH_TIME_ZONE:
            return;
        default:
            throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType);
        }
    }

    private void checkRowIndex(int row) {
        if (row < 0 || row >= rowCount) {
            throw new IndexOutOfBoundsException("Row index out of bounds: " + row);
        }
    }

    // Converts an epoch count in the given unit to an Instant; floorDiv/floorMod
    // keep pre-epoch (negative) values correct.
    private static Instant instantFromEpoch(long value, ChronoUnit unit) throws SQLException {
        switch (unit) {
        case SECONDS:
            return Instant.ofEpochSecond(value);
        case MILLIS:
            return Instant.ofEpochMilli(value);
        case MICROS: {
            long epochSecond = Math.floorDiv(value, 1_000_000L);
            long nanoAdjustment = Math.floorMod(value, 1_000_000L) * 1000L;
            return Instant.ofEpochSecond(epochSecond, nanoAdjustment);
        }
        case NANOS: {
            long epochSecond = Math.floorDiv(value, 1_000_000_000L);
            long nanoAdjustment = Math.floorMod(value, 1_000_000_000L);
            return Instant.ofEpochSecond(epochSecond, nanoAdjustment);
        }
        default:
            throw new SQLException("Unsupported unit type: " + unit);
        }
    }

    // Reinterprets a signed long as the unsigned 64-bit value it encodes.
    private static BigInteger unsignedLongToBigInteger(long value) {
        if (value >= 0) {
            return BigInteger.valueOf(value);
        }
        return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1);
    }
}
a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java new file mode 100644 index 000000000..f73a979df --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java @@ -0,0 +1,87 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import java.sql.SQLException; + +final class DuckDBVectorTypeInfo { + final DuckDBColumnType columnType; + final DuckDBBindings.CAPIType capiType; + final DuckDBBindings.CAPIType storageType; + final int widthBytes; + final DuckDBColumnTypeMetaData decimalMeta; + + private DuckDBVectorTypeInfo(DuckDBColumnType columnType, DuckDBBindings.CAPIType capiType, + DuckDBBindings.CAPIType storageType, int widthBytes, + DuckDBColumnTypeMetaData decimalMeta) { + this.columnType = columnType; + this.capiType = capiType; + this.storageType = storageType; + this.widthBytes = widthBytes; + this.decimalMeta = decimalMeta; + } + + static DuckDBVectorTypeInfo fromVector(ByteBuffer vectorRef) throws SQLException { + ByteBuffer logicalType = duckdb_vector_get_column_type(vectorRef); + if (logicalType == null) { + throw new SQLException("Cannot read vector type"); + } + + try { + DuckDBBindings.CAPIType capiType = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_get_type_id(logicalType)); + switch (capiType) { + case DUCKDB_TYPE_BOOLEAN: + return new DuckDBVectorTypeInfo(DuckDBColumnType.BOOLEAN, capiType, capiType, 1, null); + case DUCKDB_TYPE_TINYINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TINYINT, capiType, capiType, 1, null); + case DUCKDB_TYPE_UTINYINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UTINYINT, capiType, capiType, 1, null); + case DUCKDB_TYPE_SMALLINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.SMALLINT, capiType, capiType, 2, null); + case DUCKDB_TYPE_USMALLINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.USMALLINT, capiType, capiType, 2, null); + case DUCKDB_TYPE_INTEGER: + return new 
DuckDBVectorTypeInfo(DuckDBColumnType.INTEGER, capiType, capiType, 4, null); + case DUCKDB_TYPE_UINTEGER: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UINTEGER, capiType, capiType, 4, null); + case DUCKDB_TYPE_BIGINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.BIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_UBIGINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UBIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_FLOAT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.FLOAT, capiType, capiType, 4, null); + case DUCKDB_TYPE_DOUBLE: + return new DuckDBVectorTypeInfo(DuckDBColumnType.DOUBLE, capiType, capiType, 8, null); + case DUCKDB_TYPE_DATE: + return new DuckDBVectorTypeInfo(DuckDBColumnType.DATE, capiType, capiType, 4, null); + case DUCKDB_TYPE_TIMESTAMP_S: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_S, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_MS, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_NS, capiType, capiType, 8, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return new DuckDBVectorTypeInfo(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, capiType, capiType, 8, null); + case DUCKDB_TYPE_VARCHAR: + return new DuckDBVectorTypeInfo(DuckDBColumnType.VARCHAR, capiType, capiType, 16, null); + case DUCKDB_TYPE_DECIMAL: { + DuckDBBindings.CAPIType internalType = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_decimal_internal_type(logicalType)); + DuckDBColumnTypeMetaData decimalMeta = new DuckDBColumnTypeMetaData( + (short) (internalType.widthBytes * 8), (short) duckdb_decimal_width(logicalType), + (short) duckdb_decimal_scale(logicalType)); + return new DuckDBVectorTypeInfo(DuckDBColumnType.DECIMAL, capiType, internalType, + (int) 
internalType.widthBytes, decimalMeta); + } + default: + throw new SQLException("Unsupported vectorized scalar function type: " + capiType); + } + } finally { + duckdb_destroy_logical_type(logicalType); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java b/src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java new file mode 100644 index 000000000..66649f694 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java @@ -0,0 +1,16 @@ +package org.duckdb; + +@FunctionalInterface +public interface DuckDBVectorizedScalarFunction { + /** + * Processes a full input chunk and writes one output value per row directly into the DuckDB output vector. + * + *

The input and output wrappers are valid only for the duration of the callback and must not be retained. + * + * @param input input vectors for the current chunk + * @param rowCount number of rows in the current chunk + * @param out output vector for the current chunk + * @throws Exception when function execution fails + */ + void apply(DuckDBDataChunkReader input, int rowCount, DuckDBWritableVector out) throws Exception; +} diff --git a/src/main/java/org/duckdb/DuckDBWritableVector.java b/src/main/java/org/duckdb/DuckDBWritableVector.java new file mode 100644 index 000000000..606676a8e --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBWritableVector.java @@ -0,0 +1,406 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; + +public final class DuckDBWritableVector { + private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final int rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private ByteBuffer validity; + + DuckDBWritableVector(ByteBuffer vectorRef, int rowCount) throws SQLException { + if (vectorRef == null) { + throw new SQLException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + this.data = duckdb_vector_get_data(vectorRef, (long) rowCount * typeInfo.widthBytes); + this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + } + + 
package org.duckdb;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.duckdb.DuckDBBindings.*;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;

/**
 * Writable view over the DuckDB output vector of a scalar-function callback.
 *
 * <p>Typed setters validate the vector's logical type, range-check the value,
 * write it into the native data buffer in platform byte order, and mark the row
 * valid. Passing null to an object-typed setter stores SQL NULL. Instances are
 * valid only for the duration of the callback and must not be retained.
 */
public final class DuckDBWritableVector {
    // Largest value representable by an unsigned 64-bit integer (2^64 - 1).
    private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615");
    // Vector buffers are read by native code, so all multi-byte writes use the platform byte order.
    private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder();

    private final ByteBuffer vectorRef;  // native duckdb_vector handle
    private final int rowCount;
    private final DuckDBVectorTypeInfo typeInfo;
    private final ByteBuffer data;       // fixed-width data area, rowCount * widthBytes bytes
    private ByteBuffer validity;         // lazily made writable; null until a NULL is stored

    DuckDBWritableVector(ByteBuffer vectorRef, int rowCount) throws SQLException {
        if (vectorRef == null) {
            throw new SQLException("Invalid vector reference");
        }
        this.vectorRef = vectorRef;
        this.rowCount = rowCount;
        this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef);
        this.data = duckdb_vector_get_data(vectorRef, (long) rowCount * typeInfo.widthBytes);
        this.validity = duckdb_vector_get_validity(vectorRef, rowCount);
    }

    /** JDBC-facing logical type of this vector. */
    public DuckDBColumnType getType() {
        return typeInfo.columnType;
    }

    /** Number of rows in this vector. */
    public int rowCount() {
        return rowCount;
    }

    /** Stores SQL NULL in the given row, allocating a writable validity mask on demand. */
    public void setNull(int row) throws SQLException {
        checkRowIndex(row);
        ensureValidity();
        duckdb_validity_set_row_validity(validity, row, false);
    }

    /** Writes a BOOLEAN row (stored as one byte). */
    public void setBoolean(int row, boolean value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.BOOLEAN);
        data.put(row, value ? (byte) 1 : (byte) 0);
        markValid(row);
    }

    /** Writes a TINYINT row. */
    public void setByte(int row, byte value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.TINYINT);
        data.put(row, value);
        markValid(row);
    }

    /** Writes a SMALLINT row. */
    public void setShort(int row, short value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.SMALLINT);
        data.order(NATIVE_ORDER).putShort(row * Short.BYTES, value);
        markValid(row);
    }

    /** Writes a UTINYINT row; value must fit in [0, 255]. */
    public void setUint8(int row, int value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UTINYINT);
        checkUnsignedRange("UTINYINT", value, 0xFFL);
        data.put(row, (byte) value);
        markValid(row);
    }

    /** Writes a USMALLINT row; value must fit in [0, 65535]. */
    public void setUint16(int row, int value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.USMALLINT);
        checkUnsignedRange("USMALLINT", value, 0xFFFFL);
        data.order(NATIVE_ORDER).putShort(row * Short.BYTES, (short) value);
        markValid(row);
    }

    /** Writes an INTEGER row. */
    public void setInt(int row, int value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.INTEGER);
        data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, value);
        markValid(row);
    }

    /** Writes a UINTEGER row; value must fit in [0, 2^32-1]. */
    public void setUint32(int row, long value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UINTEGER);
        checkUnsignedRange("UINTEGER", value, 0xFFFFFFFFL);
        data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, (int) value);
        markValid(row);
    }

    /** Writes a BIGINT row. */
    public void setLong(int row, long value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.BIGINT);
        data.order(NATIVE_ORDER).putLong(row * Long.BYTES, value);
        markValid(row);
    }

    /**
     * Writes a UBIGINT row; value must fit in [0, 2^64-1]. The unsigned value is
     * stored via its two's-complement long bit pattern. Null stores SQL NULL.
     */
    public void setUint64(int row, BigInteger value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.UBIGINT);
        if (value == null) {
            setNull(row);
            return;
        }
        if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) {
            throw new SQLException("Value out of range for UBIGINT: " + value);
        }
        data.order(NATIVE_ORDER).putLong(row * Long.BYTES, value.longValue());
        markValid(row);
    }

    /** Writes a FLOAT row. */
    public void setFloat(int row, float value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.FLOAT);
        data.order(NATIVE_ORDER).putFloat(row * Float.BYTES, value);
        markValid(row);
    }

    /** Writes a DOUBLE row. */
    public void setDouble(int row, double value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DOUBLE);
        data.order(NATIVE_ORDER).putDouble(row * Double.BYTES, value);
        markValid(row);
    }

    /** Writes a DATE row (stored as days since the Unix epoch). Null stores SQL NULL. */
    public void setDate(int row, LocalDate value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DATE);
        if (value == null) {
            setNull(row);
            return;
        }
        long days = value.toEpochDay();
        // The native storage is a 32-bit day count.
        if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) {
            throw new SQLException("Value out of range for DATE: " + value);
        }
        data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, (int) days);
        markValid(row);
    }

    /** Writes a DATE row from java.sql.Date. Null stores SQL NULL. */
    public void setDate(int row, java.sql.Date value) throws SQLException {
        setDate(row, value == null ? null : value.toLocalDate());
    }

    /**
     * Writes a DATE row from java.util.Date. Plain Date values are interpreted in
     * UTC; java.sql.Date values use their calendar date. Null stores SQL NULL.
     */
    public void setDate(int row, java.util.Date value) throws SQLException {
        if (value == null) {
            setNull(row);
            return;
        }
        if (value instanceof java.sql.Date) {
            setDate(row, (java.sql.Date) value);
            return;
        }
        LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate();
        setDate(row, localDate);
    }

    /** Writes a TIMESTAMP* row from a LocalDateTime. Null stores SQL NULL. */
    public void setTimestamp(int row, LocalDateTime value) throws SQLException {
        checkRowIndex(row);
        requireTimestampType(false);
        if (value == null) {
            setNull(row);
            return;
        }
        data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeLocalDateTime(value));
        markValid(row);
    }

    /**
     * Writes a TIMESTAMP* row from java.sql.Timestamp. For TIMESTAMP WITH TIME
     * ZONE columns the instant is stored directly; otherwise the local date-time
     * is stored. Null stores SQL NULL.
     */
    public void setTimestamp(int row, Timestamp value) throws SQLException {
        if (value == null) {
            setNull(row);
            return;
        }
        if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) {
            checkRowIndex(row);
            data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeInstant(value.toInstant()));
            markValid(row);
            return;
        }
        setTimestamp(row, value.toLocalDateTime());
    }

    /** Writes a TIMESTAMP* row from java.util.Date (millisecond precision). Null stores SQL NULL. */
    public void setTimestamp(int row, java.util.Date value) throws SQLException {
        checkRowIndex(row);
        requireTimestampType(false);
        if (value == null) {
            setNull(row);
            return;
        }
        if (value instanceof Timestamp) {
            setTimestamp(row, (Timestamp) value);
            return;
        }
        data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeJavaUtilDate(value));
        markValid(row);
    }

    /**
     * Writes a TIMESTAMP* row from a LocalDate at start of day. For TIMESTAMP WITH
     * TIME ZONE the system default zone is used. Null stores SQL NULL.
     */
    public void setTimestamp(int row, LocalDate value) throws SQLException {
        if (value == null) {
            setNull(row);
            return;
        }
        if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) {
            checkRowIndex(row);
            Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant();
            data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeInstant(instant));
            markValid(row);
            return;
        }
        setTimestamp(row, value.atStartOfDay());
    }

    /**
     * Writes a TIMESTAMP WITH TIME ZONE row; the value is normalized to UTC and
     * stored as microseconds since epoch. Null stores SQL NULL.
     */
    public void setOffsetDateTime(int row, OffsetDateTime value) throws SQLException {
        checkRowIndex(row);
        requireTimestampType(true);
        if (value == null) {
            setNull(row);
            return;
        }
        data.order(NATIVE_ORDER)
            .putLong(row * Long.BYTES, DuckDBTimestamp.localDateTime2Micros(
                value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime()));
        markValid(row);
    }

    /**
     * Writes a DECIMAL row. The value is rescaled to the column scale (exactly;
     * values needing rounding are rejected), precision-checked, and stored in the
     * column's physical integer width. Null stores SQL NULL.
     */
    public void setBigDecimal(int row, BigDecimal value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.DECIMAL);
        if (value == null) {
            setNull(row);
            return;
        }
        BigDecimal scaled;
        try {
            // No rounding mode: setScale throws when the value cannot be represented exactly.
            scaled = value.setScale(typeInfo.decimalMeta.scale);
        } catch (ArithmeticException e) {
            throw decimalOutOfRange(value, e);
        }
        if (scaled.precision() > typeInfo.decimalMeta.width) {
            throw decimalOutOfRange(value);
        }
        switch (typeInfo.storageType) {
        case DUCKDB_TYPE_SMALLINT:
            try {
                data.order(NATIVE_ORDER).putShort(row * Short.BYTES, scaled.unscaledValue().shortValueExact());
            } catch (ArithmeticException e) {
                throw decimalOutOfRange(value, e);
            }
            break;
        case DUCKDB_TYPE_INTEGER:
            try {
                data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, scaled.unscaledValue().intValueExact());
            } catch (ArithmeticException e) {
                throw decimalOutOfRange(value, e);
            }
            break;
        case DUCKDB_TYPE_BIGINT:
            try {
                data.order(NATIVE_ORDER).putLong(row * Long.BYTES, scaled.unscaledValue().longValueExact());
            } catch (ArithmeticException e) {
                throw decimalOutOfRange(value, e);
            }
            break;
        case DUCKDB_TYPE_HUGEINT: {
            // hugeint layout: lower 64 bits first, then upper 64 bits.
            BigInteger unscaled = scaled.unscaledValue();
            ByteBuffer slice = data.duplicate().order(NATIVE_ORDER);
            slice.position(row * typeInfo.widthBytes);
            slice.putLong(unscaled.longValue());
            slice.putLong(unscaled.shiftRight(Long.SIZE).longValue());
            break;
        }
        default:
            throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType);
        }
        markValid(row);
    }

    /** Writes a VARCHAR row as UTF-8; DuckDB owns the copied string. Null stores SQL NULL. */
    public void setString(int row, String value) throws SQLException {
        checkRowIndex(row);
        requireType(DuckDBColumnType.VARCHAR);
        if (value == null) {
            setNull(row);
            return;
        }
        duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8));
        markValid(row);
    }

    // Package-private access to the native handle for sibling classes.
    ByteBuffer vectorRef() {
        return vectorRef;
    }

    // Makes the validity mask writable on first use; no-op when already available.
    private void ensureValidity() throws SQLException {
        if (validity != null) {
            return;
        }
        duckdb_vector_ensure_validity_writable(vectorRef);
        validity = duckdb_vector_get_validity(vectorRef, rowCount);
        if (validity == null) {
            throw new SQLException("Cannot initialize vector validity");
        }
    }

    // Marks the row valid when a mask exists; without a mask all rows are already valid.
    private void markValid(int row) {
        if (validity == null) {
            return;
        }
        duckdb_validity_set_row_validity(validity, row, true);
    }

    private void requireType(DuckDBColumnType expected) throws SQLException {
        if (typeInfo.columnType != expected) {
            throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType);
        }
    }

    private void checkRowIndex(int row) {
        if (row < 0 || row >= rowCount) {
            throw new IndexOutOfBoundsException("Row index out of bounds: " + row);
        }
    }

    // requireTimezone=true restricts to TIMESTAMP WITH TIME ZONE; otherwise any TIMESTAMP* variant passes.
    private void requireTimestampType(boolean requireTimezone) throws SQLException {
        if (requireTimezone) {
            if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) {
                throw new SQLException("Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType);
            }
            return;
        }
        switch (typeInfo.columnType) {
        case TIMESTAMP:
        case TIMESTAMP_S:
        case TIMESTAMP_MS:
        case TIMESTAMP_NS:
        case TIMESTAMP_WITH_TIME_ZONE:
            return;
        default:
            throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType);
        }
    }

    // Local date-times are interpreted as UTC, except for TZ columns which use
    // the system default zone to obtain the instant.
    private long encodeLocalDateTime(LocalDateTime value) throws SQLException {
        Instant instant;
        if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) {
            instant = value.atZone(ZoneId.systemDefault()).toInstant();
        } else {
            instant = value.toInstant(ZoneOffset.UTC);
        }
        return encodeInstant(instant);
    }

    private long encodeJavaUtilDate(java.util.Date value) throws SQLException {
        return encodeInstant(Instant.ofEpochMilli(value.getTime()));
    }

    // Converts an instant to the column's native epoch unit; *Exact math surfaces
    // overflow instead of silently wrapping.
    private long encodeInstant(Instant instant) throws SQLException {
        long epochSeconds = instant.getEpochSecond();
        int nano = instant.getNano();
        switch (typeInfo.capiType) {
        case DUCKDB_TYPE_TIMESTAMP_S:
            return epochSeconds;
        case DUCKDB_TYPE_TIMESTAMP_MS:
            return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L);
        case DUCKDB_TYPE_TIMESTAMP:
        case DUCKDB_TYPE_TIMESTAMP_TZ:
            return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L);
        case DUCKDB_TYPE_TIMESTAMP_NS:
            return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano);
        default:
            throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType);
        }
    }

    private static void checkUnsignedRange(String typeName, long value, long maxValue) throws SQLException {
        if (value < 0 || value > maxValue) {
            throw new SQLException("Value out of range for " + typeName + ": " + value);
        }
    }

    private SQLException decimalOutOfRange(BigDecimal value) {
        return new SQLException("Value out of range for " + decimalTypeName() + ": " + value);
    }

    private SQLException decimalOutOfRange(BigDecimal value, ArithmeticException cause) {
        return new SQLException("Value out of range for " + decimalTypeName() + ": " + value, cause);
    }

    // Renders the column type as e.g. "DECIMAL(18,3)" for error messages.
    private String decimalTypeName() {
        return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")";
    }
}
import java.sql.*; import java.util.Arrays; @@ -42,6 +44,39 @@ public static void test_bindings_logical_type() throws Exception { assertThrows(() -> { duckdb_destroy_logical_type(null); }, SQLException.class); } + public static void test_bindings_parse_logical_type() throws Exception { + ByteBuffer integerType = duckdb_jdbc_parse_logical_type(null, "INTEGER".getBytes(UTF_8)); + assertNotNull(integerType); + assertEquals(DUCKDB_TYPE_INTEGER.typeId, duckdb_get_type_id(integerType)); + duckdb_destroy_logical_type(integerType); + + ByteBuffer decimalType = duckdb_jdbc_parse_logical_type(null, "DECIMAL(18,3)".getBytes(UTF_8)); + assertNotNull(decimalType); + assertEquals(DUCKDB_TYPE_DECIMAL.typeId, duckdb_get_type_id(decimalType)); + assertEquals(18, duckdb_decimal_width(decimalType)); + assertEquals(3, duckdb_decimal_scale(decimalType)); + assertEquals(DUCKDB_TYPE_BIGINT.typeId, duckdb_decimal_internal_type(decimalType)); + duckdb_destroy_logical_type(decimalType); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')"); + + ByteBuffer enumType = duckdb_jdbc_parse_logical_type(conn.connRef, "mood".getBytes(UTF_8)); + assertNotNull(enumType); + assertEquals(DUCKDB_TYPE_ENUM.typeId, duckdb_get_type_id(enumType)); + assertEquals(3L, duckdb_enum_dictionary_size(enumType)); + assertEquals("sad".getBytes(UTF_8), duckdb_enum_dictionary_value(enumType, 0)); + duckdb_destroy_logical_type(enumType); + + assertThrows(() -> { + duckdb_jdbc_parse_logical_type(conn.connRef, "missing_type".getBytes(UTF_8)); + }, SQLException.class); + } + + assertThrows(() -> { duckdb_jdbc_parse_logical_type(null, null); }, SQLException.class); + } + public static void test_bindings_vector_create() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer vec = duckdb_create_vector(lt); @@ -108,6 
+143,73 @@ public static void test_bindings_vector_strings() throws Exception { duckdb_destroy_logical_type(lt); } + public static void test_bindings_varchar_string_bytes_null_row() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + long rowCount = duckdb_vector_size(); + ByteBuffer data = duckdb_vector_get_data(vec, rowCount * STRING_T_SIZE_BYTES); + duckdb_vector_ensure_validity_writable(vec); + ByteBuffer validity = duckdb_vector_get_validity(vec, rowCount); + + duckdb_validity_set_row_validity(validity, 0L, false); + assertNull(duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, 0L)); + assertThrows( + () -> { duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, rowCount); }, SQLException.class); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_vector_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + int expected = 0x01020304; + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); + writable.setInt(0, expected); + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Integer.BYTES); + assertEquals(rawData.order(ByteOrder.nativeOrder()).getInt(0), expected); + + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); + assertEquals(readable.getInt(0), expected); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_UBIGINT.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + assertTrue(rowCount >= 4); + BigInteger[] values = + new BigInteger[] {BigInteger.ZERO, new 
BigInteger("42"), new BigInteger("9223372036854775808"), + new BigInteger("18446744073709551615")}; + + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); + for (int i = 0; i < values.length; i++) { + writable.setUint64(i, values[i]); + } + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Long.BYTES); + ByteBuffer nativeData = rawData.order(ByteOrder.nativeOrder()); + for (int i = 0; i < values.length; i++) { + assertEquals(nativeData.getLong(i * Long.BYTES), values[i].longValue()); + } + + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); + for (int i = 0; i < values.length; i++) { + assertEquals(readable.getUint64(i), values[i]); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_vector_validity() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 3769121a7..88051efdb 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -2243,12 +2243,13 @@ public static void main(String[] args) throws Exception { Class clazz = Class.forName("org.duckdb." 
+ arg1); statusCode = runTests(new String[0], clazz); } else { - statusCode = runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, - TestAppenderCollection2D.class, TestAppenderComposite.class, - TestSingleValueAppender.class, TestBatch.class, TestBindings.class, TestClosure.class, - TestExtensionTypes.class, TestMetadata.class, TestNoLib.class, /* TestSpatial.class,*/ - TestParameterMetadata.class, TestPrepare.class, TestResults.class, - TestSessionInit.class, TestTimestamp.class, TestVariant.class); + statusCode = + runTests(args, TestDuckDBJDBC.class, TestAppender.class, TestAppenderCollection.class, + TestAppenderCollection2D.class, TestAppenderComposite.class, TestSingleValueAppender.class, + TestBatch.class, TestBindings.class, TestClosure.class, TestExtensionTypes.class, + TestMetadata.class, TestNoLib.class, TestScalarFunctions.class, + /* TestSpatial.class,*/ TestParameterMetadata.class, TestPrepare.class, TestResults.class, + TestSessionInit.class, TestTimestamp.class, TestVariant.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java new file mode 100644 index 000000000..8359d5d76 --- /dev/null +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -0,0 +1,909 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER; +import static org.duckdb.TestDuckDBJDBC.JDBC_URL; +import static org.duckdb.test.Assertions.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; 
+import java.time.ZoneOffset; + +public class TestScalarFunctions { + private interface ResultSetVerifier { + void verify(ResultSet rs) throws Exception; + } + + public static void test_bindings_scalar_function() throws Exception { + ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer scalarFunction = duckdb_create_scalar_function(); + assertNotNull(scalarFunction); + + duckdb_scalar_function_set_name(scalarFunction, "binding_scalar_fn".getBytes(UTF_8)); + duckdb_scalar_function_add_parameter(scalarFunction, intType); + duckdb_scalar_function_set_return_type(scalarFunction, intType); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class)) { + assertEquals(duckdb_register_scalar_function(conn.connRef, scalarFunction), 1); + + assertThrows(() -> { duckdb_register_scalar_function(null, scalarFunction); }, SQLException.class); + assertThrows(() -> { duckdb_register_scalar_function(conn.connRef, null); }, SQLException.class); + } + + duckdb_destroy_scalar_function(scalarFunction); + duckdb_destroy_logical_type(intType); + + assertThrows(() -> { duckdb_destroy_scalar_function(null); }, SQLException.class); + assertThrows(() -> { duckdb_scalar_function_set_name(null, "x".getBytes(UTF_8)); }, SQLException.class); + } + + public static void test_register_scalar_function() throws Exception { + test_register_scalar_function_integer(); + } + + public static void test_register_scalar_function_parallel() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("PRAGMA threads=4"); + conn.registerScalarFunction("java_add_one_bigint", new String[] {"BIGINT"}, "BIGINT", + (input, rowCount, out) -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + out.setLong(i, in.getLong(i) + 1); + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT 
sum(java_add_one_bigint(i)) FROM range(1000000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500000500000L); + assertFalse(rs.wasNull()); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_boolean() throws Exception { + assertUnaryScalarFunction("java_not_bool", "BOOLEAN", "BOOLEAN", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setBoolean(i, !in.getBoolean(i)); + } + } + }, + "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_tinyint() throws Exception { + assertUnaryScalarFunction("java_add_tinyint", "TINYINT", "TINYINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setByte(i, (byte) (in.getByte(i) + 1)); + } + } + }, + "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_smallint() throws Exception { + assertUnaryScalarFunction( + "java_add_smallint", "SMALLINT", "SMALLINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setShort(i, (short) (in.getShort(i) + 2)); + } + } + }, + 
"SELECT java_add_smallint(v) FROM (VALUES (40::SMALLINT), (NULL), (-4::SMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) -2); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer() throws Exception { + assertUnaryScalarFunction("java_add_int", "INTEGER", "INTEGER", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setInt(i, in.getInt(i) + 1); + } + } + }, + "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { + assertUnaryScalarFunction("java_revalidate_int", "INTEGER", "INTEGER", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setNull(i); + out.setInt(i, in.getInt(i) + 1); + } + } + }, + "SELECT java_revalidate_int(v) FROM (VALUES (41), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_bigint() throws Exception { + assertUnaryScalarFunction("java_add_bigint", "BIGINT", "BIGINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); 
+ } else { + out.setLong(i, in.getLong(i) + 3); + } + } + }, + "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_utinyint() throws Exception { + assertUnaryScalarFunction( + "java_add_utinyint", "UTINYINT", "UTINYINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setUint8(i, in.getUint8(i) + 1); + } + } + }, + "SELECT java_add_utinyint(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 255); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_usmallint() throws Exception { + assertUnaryScalarFunction( + "java_add_usmallint", "USMALLINT", "USMALLINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setUint16(i, in.getUint16(i) + 2); + } + } + }, + "SELECT java_add_usmallint(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 65535); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_uinteger() throws Exception { + assertUnaryScalarFunction( + "java_add_uinteger", "UINTEGER", 
"UINTEGER", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setUint32(i, in.getUint32(i) + 3); + } + } + }, + "SELECT java_add_uinteger(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 4294967295L); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_ubigint() throws Exception { + assertUnaryScalarFunction("java_add_ubigint", "UBIGINT", "UBIGINT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + BigInteger increment = BigInteger.ONE; + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setUint64(i, in.getUint64(i).add(increment)); + } + } + }, + "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_float() throws Exception { + assertUnaryScalarFunction("java_add_float", "FLOAT", "FLOAT", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setFloat(i, in.getFloat(i) + 1.25f); + } + } + }, + "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), 42.0f); + assertTrue(rs.next()); + 
assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), -1.25f); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_double() throws Exception { + assertUnaryScalarFunction("java_add_double", "DOUBLE", "DOUBLE", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setDouble(i, in.getDouble(i) + 1.5d); + } + } + }, + "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_decimal() throws Exception { + assertUnaryScalarFunction("java_add_decimal", "DECIMAL(38,10)", "DECIMAL(38,10)", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + BigDecimal increment = new BigDecimal("0.0000000001"); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setBigDecimal(i, in.getBigDecimal(i).add(increment)); + } + } + }, + "SELECT java_add_decimal(v) FROM (VALUES " + + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " + + "(NULL), " + + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), + new BigDecimal("12345678901234567890.1234567891")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_decimal_precision_overflow() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarFunction("java_decimal_precision_overflow", new String[] {"DECIMAL(10,2)"}, + "DECIMAL(10,2)", (input, rowCount, out) -> { + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("12345678901.23")); + } + }); + + String err = assertThrows(() -> { + stmt.execute("SELECT java_decimal_precision_overflow(CAST(1 AS DECIMAL(10,2)))"); + }, SQLException.class); + assertTrue(err.contains("DECIMAL(10,2)")); + } + } + + public static void test_register_scalar_function_decimal_scale_overflow() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarFunction("java_decimal_scale_overflow", new String[] {"DECIMAL(10,2)"}, "DECIMAL(10,2)", + (input, rowCount, out) -> { + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("1.234")); + } + }); + + String err = assertThrows(() -> { + stmt.execute("SELECT java_decimal_scale_overflow(CAST(1 AS DECIMAL(10,2)))"); + }, SQLException.class); + assertTrue(err.contains("DECIMAL(10,2)")); + } + } + + public static void test_register_scalar_function_date() throws Exception { + assertUnaryScalarFunction( + "java_add_date", "DATE", "DATE", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setDate(i, in.getLocalDate(i).plusDays(2)); + } + } + }, + "SELECT java_add_date(v) FROM (VALUES (DATE '2024-07-20'), (NULL), (DATE '1969-12-31')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 22)); + assertEquals(rs.getDate(1), Date.valueOf(LocalDate.of(2024, 7, 22))); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + 
assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(1970, 1, 2)); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_date_from_java_util_date() throws Exception { + assertUnaryScalarFunction("java_date_from_util_date", "DATE", "DATE", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + LocalDate value = in.getLocalDate(i).plusDays(1); + out.setDate(i, java.util.Date.from(value.atStartOfDay(UTC).toInstant())); + } + } + }, + "SELECT java_date_from_util_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 22)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp() throws Exception { + assertUnaryScalarFunction("java_add_timestamp", "TIMESTAMP", "TIMESTAMP", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i).plusMinutes(30)); + } + } + }, + "SELECT java_add_timestamp(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + Timestamp ts1 = rs.getTimestamp(1); + assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_s() throws Exception { + assertUnaryScalarFunction( + 
"java_add_timestamp_s", "TIMESTAMP_S", "TIMESTAMP_S", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i).plusSeconds(2)); + } + } + }, + "SELECT java_add_timestamp_s(v) FROM (VALUES (TIMESTAMP_S '2024-07-21 12:34:56'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-21 12:34:58")); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { + assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", "TIMESTAMP", "TIMESTAMP_S", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i)); + } + } + }, + "SELECT java_copy_timestamp_s_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.999')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1969-12-31 23:59:59")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ms() throws Exception { + assertUnaryScalarFunction("java_add_timestamp_ms", "TIMESTAMP_MS", "TIMESTAMP_MS", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i).plusNanos(7_000_000)); + } + } + }, + "SELECT java_add_timestamp_ms(v) FROM (VALUES " + + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); + assertTrue(rs.next()); + assertNullRow(rs); + 
assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { + assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", "TIMESTAMP", "TIMESTAMP_MS", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i)); + } + } + }, + "SELECT java_copy_timestamp_ms_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.9995')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999_000_000)); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_ns() throws Exception { + assertUnaryScalarFunction("java_add_timestamp_ns", "TIMESTAMP_NS", "TIMESTAMP_NS", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDateTime(i).plusNanos(789)); + } + } + }, + "SELECT java_add_timestamp_ns(v) FROM (VALUES " + + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamptz() throws Exception { + assertUnaryScalarFunction( + "java_add_timestamptz", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH TIME ZONE", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setOffsetDateTime(i, in.getOffsetDateTime(i).plusMinutes(5)); + } + } + }, + "SELECT java_add_timestamptz(v) FROM (VALUES 
" + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamptz_set_timestamp() throws Exception { + assertUnaryScalarFunction( + "java_copy_timestamptz_with_timestamp", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH TIME ZONE", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getTimestamp(i)); + } + } + }, + "SELECT java_copy_timestamptz_with_timestamp(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00')) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 34, 56, 123456000, ZoneOffset.UTC))); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_date", "TIMESTAMP", "TIMESTAMP", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + long oneSecondMillis = 1000L; + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, new java.util.Date(in.getTimestamp(i).getTime() + oneSecondMillis)); + } + } + }, + "SELECT epoch_ms(java_timestamp_from_util_date(v)) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + Timestamp input = Timestamp.valueOf("2024-07-21 12:34:56.123456"); + assertEquals(rs.getLong(1), input.getTime() + 1000L); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + 
assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_ts", "TIMESTAMP", "TIMESTAMP", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + java.util.Date value = Timestamp.valueOf(in.getLocalDateTime(i).plusNanos(789000)); + out.setTimestamp(i, value); + } + } + }, + "SELECT java_timestamp_from_util_ts(v) FROM (VALUES (TIMESTAMP '2024-07-21 12:34:56.123456')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-21 12:34:56.124245")); + assertFalse(rs.wasNull()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_sql_date", "DATE", "TIMESTAMP", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + java.util.Date value = Date.valueOf(in.getLocalDate(i)); + out.setTimestamp(i, value); + } + } + }, + "SELECT epoch_ms(java_timestamp_from_util_sql_date(v)) FROM (VALUES (DATE '2024-07-21')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), Date.valueOf("2024-07-21").getTime()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_util_sql_time", "TIMESTAMP", "TIMESTAMP", + (input, rowCount, out) + -> { + for (int i = 0; i < rowCount; i++) { + java.util.Date value = Time.valueOf("12:34:56"); + out.setTimestamp(i, value); + } + }, + "SELECT epoch_ms(java_timestamp_from_util_sql_time(v)) FROM (VALUES (TIMESTAMP 
'2024-07-21 00:00:00')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), Time.valueOf("12:34:56").getTime()); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { + assertUnaryScalarFunction( + "java_timestamp_from_local_date", "DATE", "TIMESTAMP", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setTimestamp(i, in.getLocalDate(i).plusDays(1)); + } + } + }, + "SELECT java_timestamp_from_local_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("2024-07-22 00:00:00")); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 22, 0, 0)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar() throws Exception { + assertUnaryScalarFunction("java_suffix_varchar", "VARCHAR", "VARCHAR", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setString(i, in.getString(i) + "_java"); + } + } + }, + "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_java"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { + assertUnaryScalarFunction("java_echo_varchar_nullable", "VARCHAR", "VARCHAR", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int 
i = 0; i < rowCount; i++) { + out.setString(i, in.getString(i)); + } + }, + "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { + assertUnaryScalarFunction("java_revalidate_varchar", "VARCHAR", "VARCHAR", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setNull(i); + out.setString(i, in.getString(i) + "_ok"); + } + } + }, + "SELECT java_revalidate_varchar(v) FROM (VALUES ('duck'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + private static void assertUnaryScalarFunction(String functionName, String parameterType, String returnType, + DuckDBVectorizedScalarFunction function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarFunction(functionName, new String[] {parameterType}, returnType, function); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullRow(ResultSet rs) throws Exception { + assertEquals(rs.getObject(1), null); + assertTrue(rs.wasNull()); + } + + private static final java.time.ZoneOffset UTC = java.time.ZoneOffset.UTC; +} From c12a32ee5c3f2b04b79b76257ddae9c869528ccf Mon Sep 17 00:00:00 2001 From: Luis 
Fernando Kauer Date: Tue, 31 Mar 2026 18:03:38 -0300 Subject: [PATCH 2/9] refine scalar udf type parsing and isolate jni scalar bridge --- CMakeLists.txt | 1 + UDF.MD | 14 +- duckdb_java.def | 1 + duckdb_java.exp | 1 + duckdb_java.map | 1 + src/jni/bindings_logical_type.cpp | 20 ++ src/jni/duckdb_java.cpp | 288 ----------------- src/jni/scalar_functions.cpp | 300 ++++++++++++++++++ src/jni/scalar_functions.hpp | 6 + src/main/java/org/duckdb/DuckDBBindings.java | 2 + .../java/org/duckdb/DuckDBConnection.java | 86 +++-- .../java/org/duckdb/DuckDBLogicalType.java | 172 ++++++++++ src/test/java/org/duckdb/TestBindings.java | 37 ++- .../java/org/duckdb/TestScalarFunctions.java | 53 ++++ 14 files changed, 661 insertions(+), 321 deletions(-) create mode 100644 src/jni/scalar_functions.cpp create mode 100644 src/jni/scalar_functions.hpp create mode 100644 src/main/java/org/duckdb/DuckDBLogicalType.java diff --git a/CMakeLists.txt b/CMakeLists.txt index e5bbe070a..7bd2316a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -598,6 +598,7 @@ add_library(duckdb_java SHARED src/jni/duckdb_java.cpp src/jni/functions.cpp src/jni/refs.cpp + src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/UDF.MD b/UDF.MD index 7b81cab9e..7fed83857 100644 --- a/UDF.MD +++ b/UDF.MD @@ -11,8 +11,20 @@ void registerScalarFunction( ) throws SQLException ``` +You can also use the typed overload: + +```java +void registerScalarFunction( + String name, + DuckDBLogicalType[] parameterTypes, + DuckDBLogicalType returnType, + DuckDBVectorizedScalarFunction function +) throws SQLException +``` + Notes: -- `parameterTypes` and `returnType` are SQL type strings (for example: `INTEGER`, `VARCHAR`, `TIMESTAMP`). +- The string overload accepts SQL types like `INTEGER`, `VARCHAR`, `TIMESTAMP`, `DECIMAL(18,2)`. +- The typed overload uses `DuckDBLogicalType` (C API style) and avoids SQL text parsing. 
- The callback is vectorized: process `rowCount` rows from the input chunk and write one value per row into `out`. - `DuckDBDataChunkReader` / `DuckDBReadableVector` / `DuckDBWritableVector` are valid only during the callback. diff --git a/duckdb_java.def b/duckdb_java.def index 8edfc8220..bd2e5f6d8 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -60,6 +60,7 @@ Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width diff --git a/duckdb_java.exp b/duckdb_java.exp index b6811a443..39d462620 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -57,6 +57,7 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type _Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function _Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type _Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type _Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id diff --git a/duckdb_java.map b/duckdb_java.map index 635070538..cdf25f812 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -59,6 +59,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type; 
Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width; diff --git a/src/jni/bindings_logical_type.cpp b/src/jni/bindings_logical_type.cpp index ac4f0c568..a72750e9d 100644 --- a/src/jni/bindings_logical_type.cpp +++ b/src/jni/bindings_logical_type.cpp @@ -38,6 +38,26 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical return make_ptr_buf(env, lt); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_create_decimal_type + * Signature: (II)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type(JNIEnv *env, jclass, jint width, + jint scale) { + + if (width < 1 || width > 38) { + env->ThrowNew(J_SQLException, "Invalid decimal width: expected 1..38"); + return nullptr; + } + if (scale < 0 || scale > width) { + env->ThrowNew(J_SQLException, "Invalid decimal scale: expected 0..width"); + return nullptr; + } + duckdb_logical_type lt = duckdb_create_decimal_type(static_cast(width), static_cast(scale)); + return make_ptr_buf(env, lt); +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_jdbc_parse_logical_type diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 408e6c30c..22c51f16c 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -30,7 +30,6 @@ using namespace duckdb; using namespace std; static jint JNI_VERSION = JNI_VERSION_1_6; -static JavaVM *JVM_REF = nullptr; void ThrowJNI(JNIEnv *env, const char *message) { D_ASSERT(J_SQLException); @@ -42,7 +41,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { return JNI_ERR; } - JVM_REF = vm; try { create_refs(env); @@ -65,283 +63,10 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { return; } delete_global_refs(env); - JVM_REF = nullptr; -} - -struct JNIEnvGuard { - 
JavaVM *vm; - JNIEnv *env; - bool detach_when_done; - - explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { - if (!vm) { - throw InvalidInputException("JVM is not available"); - } - auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); - if (get_env_status == JNI_OK) { - return; - } - if (get_env_status != JNI_EDETACHED) { - throw InvalidInputException("Failed to get JNI environment"); - } - auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - if (attach_status != JNI_OK || !env) { - throw InvalidInputException("Failed to attach current thread to JVM"); - } - detach_when_done = true; - } - - ~JNIEnvGuard() { - if (detach_when_done && vm) { - vm->DetachCurrentThread(); - } - } -}; - -struct JavaScalarFunctionState { - JavaVM *vm; - jobject callback; - jmethodID apply_method; - - JavaScalarFunctionState(JavaVM *vm_p, jobject callback_p, jmethodID apply_method_p) - : vm(vm_p), callback(callback_p), apply_method(apply_method_p) { - } - - ~JavaScalarFunctionState() { - if (!vm || !callback) { - return; - } - try { - JNIEnvGuard env_guard(vm); - env_guard.env->DeleteGlobalRef(callback); - } catch (...) 
{ - // noop in destructor - } - } -}; - -struct JavaScalarFunctionLocalState { - JavaVM *vm; - JNIEnv *env; - bool detach_when_done; -}; - -static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { - if (scalar_function_buf == nullptr) { - env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); - return nullptr; - } - - auto scalar_function = reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); - if (scalar_function == nullptr) { - env->ThrowNew(J_SQLException, "Invalid scalar function"); - return nullptr; - } - return scalar_function; } jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_count); -static string consume_java_exception_message(JNIEnv *env) { - auto throwable = env->ExceptionOccurred(); - if (!throwable) { - return "Java exception"; - } - env->ExceptionClear(); - - string message = "Java exception"; - auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); - if (!env->ExceptionCheck() && msg) { - message = jstring_to_string(env, msg); - } - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } - - env->DeleteLocalRef(throwable); - if (msg) { - env->DeleteLocalRef(msg); - } - - return message; -} - -static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { - if (!vm) { - throw InvalidInputException("JVM is not available"); - } - - detach_when_done = false; - auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); - if (get_env_status == JNI_OK) { - return; - } - if (get_env_status != JNI_EDETACHED) { - throw InvalidInputException("Failed to get JNI environment"); - } - - auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - if (attach_status != JNI_OK || !env) { - throw InvalidInputException("Failed to attach current thread to JVM"); - } - detach_when_done = true; -} - -static void execute_java_vectorized_scalar_function(JNIEnv *env, 
JavaScalarFunctionState &state, DataChunk &input, - Vector &output) { - auto row_count = input.size(); - jobject input_chunk_buf = make_ptr_buf(env, &input); - jobject output_vector_buf = make_ptr_buf(env, &output); - auto input_reader = env->NewObject(J_DuckDataChunkReader, J_DuckDataChunkReader_init, input_chunk_buf, - static_cast(row_count)); - if (env->ExceptionCheck()) { - if (input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - throw InvalidInputException("Could not create DuckDBDataChunkReader: %s", consume_java_exception_message(env)); - } - - auto output_writer = env->NewObject(J_DuckWritableVector, J_DuckWritableVector_init, output_vector_buf, - static_cast(row_count)); - if (env->ExceptionCheck()) { - env->DeleteLocalRef(input_reader); - if (input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - throw InvalidInputException("Could not create DuckDBWritableVector: %s", consume_java_exception_message(env)); - } - - env->CallVoidMethod(state.callback, state.apply_method, input_reader, static_cast(row_count), output_writer); - - env->DeleteLocalRef(output_writer); - env->DeleteLocalRef(input_reader); - if (input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - - if (env->ExceptionCheck()) { - throw InvalidInputException("Java scalar function threw exception: %s", consume_java_exception_message(env)); - } -} - -static void destroy_java_scalar_function_state(void *extra_info); -static void init_java_scalar_function_capi(duckdb_init_info info); -static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); - -static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, const char *signature, - const char *error_message) { - auto 
callback_class = env->GetObjectClass(function_j); - auto apply_method = env->GetMethodID(callback_class, "apply", signature); - env->DeleteLocalRef(callback_class); - if (!apply_method || env->ExceptionCheck()) { - consume_java_exception_message(env); - throw InvalidInputException("%s", error_message); - } - return apply_method; -} - -static void install_java_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, - jobject function_j, const char *signature, - const char *error_message) { - auto connection = get_connection(env, conn_ref_buf); - if (!connection) { - throw InvalidInputException("Invalid connection"); - } - auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); - if (env->ExceptionCheck()) { - return; - } - if (!function_j) { - throw InvalidInputException("Invalid scalar function callback"); - } - - auto callback_ref = env->NewGlobalRef(function_j); - if (!callback_ref) { - throw InvalidInputException("Could not create global reference for scalar function callback"); - } - - try { - auto apply_method = get_scalar_callback_method(env, function_j, signature, error_message); - auto state = new JavaScalarFunctionState(JVM_REF, callback_ref, apply_method); - duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); - duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); - duckdb_scalar_function_set_init(scalar_function, init_java_scalar_function_capi); - } catch (...) 
{ - env->DeleteGlobalRef(callback_ref); - throw; - } -} - -static void destroy_java_scalar_function_state(void *extra_info) { - if (!extra_info) { - return; - } - delete reinterpret_cast(extra_info); -} - -static void destroy_java_scalar_function_local_state(void *state_ptr) { - if (!state_ptr) { - return; - } - - auto state = reinterpret_cast(state_ptr); - if (state->detach_when_done && state->vm) { - state->vm->DetachCurrentThread(); - } - delete state; -} - -static void init_java_scalar_function_capi(duckdb_init_info info) { - JavaScalarFunctionLocalState *local_state = nullptr; - try { - auto state = reinterpret_cast(duckdb_scalar_function_init_get_extra_info(info)); - if (!state) { - duckdb_scalar_function_init_set_error(info, "Invalid Java scalar function callback state"); - return; - } - - local_state = new JavaScalarFunctionLocalState(); - local_state->vm = state->vm; - local_state->env = nullptr; - local_state->detach_when_done = false; - get_or_attach_jni_env(local_state->vm, local_state->env, local_state->detach_when_done); - duckdb_scalar_function_init_set_state(info, local_state, destroy_java_scalar_function_local_state); - local_state = nullptr; - } catch (const std::exception &e) { - if (local_state) { - destroy_java_scalar_function_local_state(local_state); - } - duckdb_scalar_function_init_set_error(info, e.what()); - } -} - -static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, - duckdb_vector output) { - auto state = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); - auto local_state = reinterpret_cast(duckdb_scalar_function_get_state(info)); - if (!state || !local_state || !local_state->env || !input || !output) { - duckdb_scalar_function_set_error(info, "Invalid Java scalar function callback state"); - return; - } - - try { - auto &input_chunk = *reinterpret_cast(input); - auto &output_vector = *reinterpret_cast(output); - execute_java_vectorized_scalar_function(local_state->env, *state, 
input_chunk, output_vector); - } catch (const std::exception &e) { - duckdb_scalar_function_set_error(info, e.what()); - } -} - //! The database instance cache, used so that multiple connections to the same file point to the same database object duckdb::DBInstanceCache instance_cache; @@ -1177,19 +902,6 @@ void _duckdb_jdbc_arrow_register(JNIEnv *env, jclass, jobject conn_ref_buf, jlon conn->TableFunction("arrow_scan_dumb", parameters)->CreateView(name, true, true); } -extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback( - JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { - try { - install_java_scalar_function_callback(env, conn_ref_buf, scalar_function_buf, function_j, - "(Lorg/duckdb/DuckDBDataChunkReader;ILorg/duckdb/DuckDBWritableVector;)V", - "Could not find apply(DuckDBDataChunkReader, int, DuckDBWritableVector) " - "on scalar function callback"); - } catch (const std::exception &e) { - duckdb::ErrorData error(e); - ThrowJNI(env, error.Message().c_str()); - } -} - void _duckdb_jdbc_create_extension_type(JNIEnv *env, jclass, jobject conn_buf) { auto connection = get_connection(env, conn_buf); diff --git a/src/jni/scalar_functions.cpp b/src/jni/scalar_functions.cpp new file mode 100644 index 000000000..6c3e5682e --- /dev/null +++ b/src/jni/scalar_functions.cpp @@ -0,0 +1,300 @@ +extern "C" { +#include "duckdb.h" +} + +#include "duckdb.hpp" +#include "functions.hpp" +#include "holders.hpp" +#include "refs.hpp" +#include "scalar_functions.hpp" +#include "util.hpp" + +using namespace duckdb; + +struct JNIEnvGuard { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; + + explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { + if (!vm) { + throw InvalidInputException("JVM is not available"); + } + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (get_env_status == JNI_OK) { + return; + } + if 
(get_env_status != JNI_EDETACHED) { + throw InvalidInputException("Failed to get JNI environment"); + } + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw InvalidInputException("Failed to attach current thread to JVM"); + } + detach_when_done = true; + } + + ~JNIEnvGuard() { + if (detach_when_done && vm) { + vm->DetachCurrentThread(); + } + } +}; + +struct JavaScalarFunctionState { + JavaVM *vm; + jobject callback; + jmethodID apply_method; + + JavaScalarFunctionState(JavaVM *vm_p, jobject callback_p, jmethodID apply_method_p) + : vm(vm_p), callback(callback_p), apply_method(apply_method_p) { + } + + ~JavaScalarFunctionState() { + if (!vm || !callback) { + return; + } + try { + JNIEnvGuard env_guard(vm); + env_guard.env->DeleteGlobalRef(callback); + } catch (...) { + // noop in destructor + } + } +}; + +struct JavaScalarFunctionLocalState { + JavaVM *vm; + JNIEnv *env; + bool detach_when_done; +}; + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + auto scalar_function = reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + return scalar_function; +} + +static string consume_java_exception_message(JNIEnv *env) { + auto throwable = env->ExceptionOccurred(); + if (!throwable) { + return "Java exception"; + } + env->ExceptionClear(); + + string message = "Java exception"; + auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); + if (!env->ExceptionCheck() && msg) { + message = jstring_to_string(env, msg); + } + if (env->ExceptionCheck()) { + env->ExceptionClear(); + } + + env->DeleteLocalRef(throwable); + if (msg) { + 
env->DeleteLocalRef(msg); + } + + return message; +} + +static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { + if (!vm) { + throw InvalidInputException("JVM is not available"); + } + + detach_when_done = false; + auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (get_env_status == JNI_OK) { + return; + } + if (get_env_status != JNI_EDETACHED) { + throw InvalidInputException("Failed to get JNI environment"); + } + + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || !env) { + throw InvalidInputException("Failed to attach current thread to JVM"); + } + detach_when_done = true; +} + +static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, DataChunk &input, + Vector &output) { + auto row_count = input.size(); + jobject input_chunk_buf = make_ptr_buf(env, &input); + jobject output_vector_buf = make_ptr_buf(env, &output); + auto input_reader = env->NewObject(J_DuckDataChunkReader, J_DuckDataChunkReader_init, input_chunk_buf, + static_cast(row_count)); + if (env->ExceptionCheck()) { + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + throw InvalidInputException("Could not create DuckDBDataChunkReader: %s", consume_java_exception_message(env)); + } + + auto output_writer = env->NewObject(J_DuckWritableVector, J_DuckWritableVector_init, output_vector_buf, + static_cast(row_count)); + if (env->ExceptionCheck()) { + env->DeleteLocalRef(input_reader); + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + throw InvalidInputException("Could not create DuckDBWritableVector: %s", consume_java_exception_message(env)); + } + + env->CallVoidMethod(state.callback, state.apply_method, input_reader, static_cast(row_count), output_writer); + + 
env->DeleteLocalRef(output_writer); + env->DeleteLocalRef(input_reader); + if (input_chunk_buf) { + env->DeleteLocalRef(input_chunk_buf); + } + if (output_vector_buf) { + env->DeleteLocalRef(output_vector_buf); + } + + if (env->ExceptionCheck()) { + throw InvalidInputException("Java scalar function threw exception: %s", consume_java_exception_message(env)); + } +} + +static void destroy_java_scalar_function_state(void *extra_info); +static void init_java_scalar_function_capi(duckdb_init_info info); +static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); + +static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, const char *signature, + const char *error_message) { + auto callback_class = env->GetObjectClass(function_j); + auto apply_method = env->GetMethodID(callback_class, "apply", signature); + env->DeleteLocalRef(callback_class); + if (!apply_method || env->ExceptionCheck()) { + consume_java_exception_message(env); + throw InvalidInputException("%s", error_message); + } + return apply_method; +} + +void duckdb_jdbc_install_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, + jobject function_j) { + auto connection = get_connection(env, conn_ref_buf); + if (!connection) { + throw InvalidInputException("Invalid connection"); + } + auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); + if (env->ExceptionCheck()) { + return; + } + if (!function_j) { + throw InvalidInputException("Invalid scalar function callback"); + } + + JavaVM *vm = nullptr; + if (env->GetJavaVM(&vm) != JNI_OK || !vm) { + throw InvalidInputException("Failed to get JVM reference"); + } + + auto callback_ref = env->NewGlobalRef(function_j); + if (!callback_ref) { + throw InvalidInputException("Could not create global reference for scalar function callback"); + } + + try { + auto apply_method = get_scalar_callback_method( + env, function_j, 
"(Lorg/duckdb/DuckDBDataChunkReader;ILorg/duckdb/DuckDBWritableVector;)V", + "Could not find apply(DuckDBDataChunkReader, int, DuckDBWritableVector) on scalar function callback"); + auto state = new JavaScalarFunctionState(vm, callback_ref, apply_method); + duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); + duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); + duckdb_scalar_function_set_init(scalar_function, init_java_scalar_function_capi); + } catch (...) { + env->DeleteGlobalRef(callback_ref); + throw; + } +} + +static void destroy_java_scalar_function_state(void *extra_info) { + if (!extra_info) { + return; + } + delete reinterpret_cast(extra_info); +} + +static void destroy_java_scalar_function_local_state(void *state_ptr) { + if (!state_ptr) { + return; + } + + auto state = reinterpret_cast(state_ptr); + if (state->detach_when_done && state->vm) { + state->vm->DetachCurrentThread(); + } + delete state; +} + +static void init_java_scalar_function_capi(duckdb_init_info info) { + JavaScalarFunctionLocalState *local_state = nullptr; + try { + auto state = reinterpret_cast(duckdb_scalar_function_init_get_extra_info(info)); + if (!state) { + duckdb_scalar_function_init_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + local_state = new JavaScalarFunctionLocalState(); + local_state->vm = state->vm; + local_state->env = nullptr; + local_state->detach_when_done = false; + get_or_attach_jni_env(local_state->vm, local_state->env, local_state->detach_when_done); + duckdb_scalar_function_init_set_state(info, local_state, destroy_java_scalar_function_local_state); + local_state = nullptr; + } catch (const std::exception &e) { + if (local_state) { + destroy_java_scalar_function_local_state(local_state); + } + duckdb_scalar_function_init_set_error(info, e.what()); + } +} + +static void execute_java_scalar_function_capi(duckdb_function_info info, 
duckdb_data_chunk input, + duckdb_vector output) { + auto state = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); + auto local_state = reinterpret_cast(duckdb_scalar_function_get_state(info)); + if (!state || !local_state || !local_state->env || !input || !output) { + duckdb_scalar_function_set_error(info, "Invalid Java scalar function callback state"); + return; + } + + try { + auto &input_chunk = *reinterpret_cast(input); + auto &output_vector = *reinterpret_cast(output); + execute_java_vectorized_scalar_function(local_state->env, *state, input_chunk, output_vector); + } catch (const std::exception &e) { + duckdb_scalar_function_set_error(info, e.what()); + } +} + +extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback( + JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { + try { + duckdb_jdbc_install_scalar_function_callback(env, conn_ref_buf, scalar_function_buf, function_j); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} diff --git a/src/jni/scalar_functions.hpp b/src/jni/scalar_functions.hpp new file mode 100644 index 000000000..0ff981a6d --- /dev/null +++ b/src/jni/scalar_functions.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include "bindings.hpp" + +void duckdb_jdbc_install_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, + jobject function_j); diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 3e0d00b0e..1f1a80e6b 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -41,6 +41,8 @@ static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer vectorData, Byt static native ByteBuffer duckdb_create_logical_type(int duckdb_type); + static native ByteBuffer duckdb_create_decimal_type(int width, int scale); + static native 
ByteBuffer duckdb_jdbc_parse_logical_type(ByteBuffer connection, byte[] type_name); static native int duckdb_get_type_id(ByteBuffer logical_type); diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 5acbeaa4a..ae4f36242 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -503,10 +503,40 @@ public void registerArrowStream(String name, Object arrow_array_stream) { public void registerScalarFunction(String name, String[] parameterTypes, String returnType, DuckDBVectorizedScalarFunction function) throws SQLException { checkOpen(); + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + DuckDBLogicalType[] parsedParameterTypes = new DuckDBLogicalType[parameterTypes.length]; + DuckDBLogicalType parsedReturnType = null; + try { + if (name == null || name.trim().isEmpty()) { + throw new SQLException("Function name cannot be null or empty"); + } + for (int i = 0; i < parameterTypes.length; i++) { + String parameterType = parameterTypes[i]; + if (parameterType == null || parameterType.trim().isEmpty()) { + throw new SQLException("Parameter type at index " + i + " cannot be null or empty"); + } + parsedParameterTypes[i] = parseStringLogicalType(parameterType); + } + if (returnType == null || returnType.trim().isEmpty()) { + throw new SQLException("Return type cannot be null or empty"); + } + parsedReturnType = parseStringLogicalType(returnType); + registerScalarFunction(name, parsedParameterTypes, parsedReturnType, function); + } finally { + closeLogicalType(parsedReturnType); + for (DuckDBLogicalType parameterType : parsedParameterTypes) { + closeLogicalType(parameterType); + } + } + } + + public void registerScalarFunction(String name, DuckDBLogicalType[] parameterTypes, DuckDBLogicalType returnType, + DuckDBVectorizedScalarFunction function) throws SQLException { + checkOpen(); connRefLock.lock(); 
ByteBuffer scalarFunction = null; - ByteBuffer returnLogicalType = null; - ByteBuffer[] parameterLogicalTypes = null; try { checkOpen(); if (name == null || name.trim().isEmpty()) { @@ -516,13 +546,12 @@ public void registerScalarFunction(String name, String[] parameterTypes, String throw new SQLException("Parameter types cannot be null"); } for (int i = 0; i < parameterTypes.length; i++) { - String parameterType = parameterTypes[i]; - if (parameterType == null || parameterType.trim().isEmpty()) { - throw new SQLException("Parameter type at index " + i + " cannot be null or empty"); + if (parameterTypes[i] == null) { + throw new SQLException("Parameter type at index " + i + " cannot be null"); } } - if (returnType == null || returnType.trim().isEmpty()) { - throw new SQLException("Return type cannot be null or empty"); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); } if (function == null) { throw new SQLException("Scalar function callback cannot be null"); @@ -531,31 +560,17 @@ public void registerScalarFunction(String name, String[] parameterTypes, String scalarFunction = DuckDBBindings.duckdb_create_scalar_function(); DuckDBBindings.duckdb_scalar_function_set_name(scalarFunction, name.getBytes(UTF_8)); - parameterLogicalTypes = new ByteBuffer[parameterTypes.length]; for (int i = 0; i < parameterTypes.length; i++) { - parameterLogicalTypes[i] = - DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, parameterTypes[i].getBytes(UTF_8)); - DuckDBBindings.duckdb_scalar_function_add_parameter(scalarFunction, parameterLogicalTypes[i]); + DuckDBBindings.duckdb_scalar_function_add_parameter(scalarFunction, parameterTypes[i].logicalTypeRef()); } - returnLogicalType = DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, returnType.getBytes(UTF_8)); - DuckDBBindings.duckdb_scalar_function_set_return_type(scalarFunction, returnLogicalType); + DuckDBBindings.duckdb_scalar_function_set_return_type(scalarFunction, 
returnType.logicalTypeRef()); DuckDBBindings.duckdb_jdbc_scalar_function_set_callback(connRef, scalarFunction, function); if (DuckDBBindings.duckdb_register_scalar_function(connRef, scalarFunction) != 0) { throw new SQLException("Failed to register scalar function '" + name + "'"); } } finally { - if (returnLogicalType != null) { - DuckDBBindings.duckdb_destroy_logical_type(returnLogicalType); - } - if (parameterLogicalTypes != null) { - for (ByteBuffer parameterLogicalType : parameterLogicalTypes) { - if (parameterLogicalType != null) { - DuckDBBindings.duckdb_destroy_logical_type(parameterLogicalType); - } - } - } if (scalarFunction != null) { DuckDBBindings.duckdb_destroy_scalar_function(scalarFunction); } @@ -582,6 +597,31 @@ public String getSessionInitSQL() throws SQLException { return sessionInitSQL; } + private static void closeLogicalType(DuckDBLogicalType logicalType) { + if (logicalType != null) { + logicalType.close(); + } + } + + private DuckDBLogicalType parseStringLogicalType(String typeName) throws SQLException { + try { + return DuckDBLogicalType.parse(typeName); + } catch (SQLException javaParseError) { + connRefLock.lock(); + try { + checkOpen(); + ByteBuffer parsedType = + DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, typeName.getBytes(UTF_8)); + return DuckDBLogicalType.fromLogicalTypeRef(parsedType); + } catch (SQLException nativeParseError) { + nativeParseError.addSuppressed(javaParseError); + throw nativeParseError; + } finally { + connRefLock.unlock(); + } + } + } + void checkOpen() throws SQLException { if (isClosed()) { throw new SQLException("Connection was closed"); diff --git a/src/main/java/org/duckdb/DuckDBLogicalType.java b/src/main/java/org/duckdb/DuckDBLogicalType.java new file mode 100644 index 000000000..2615c79ef --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBLogicalType.java @@ -0,0 +1,172 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import 
java.sql.SQLException; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public final class DuckDBLogicalType implements AutoCloseable { + private static final Pattern DECIMAL_PATTERN = + Pattern.compile("^(DECIMAL|NUMERIC)\\s*\\(\\s*(\\d+)\\s*(?:,\\s*(\\d+)\\s*)?\\)$", Pattern.CASE_INSENSITIVE); + + private ByteBuffer logicalTypeRef; + + private DuckDBLogicalType(ByteBuffer logicalTypeRef) throws SQLException { + if (logicalTypeRef == null) { + throw new SQLException("Failed to create logical type"); + } + this.logicalTypeRef = logicalTypeRef; + } + + public static DuckDBLogicalType of(DuckDBColumnType type) throws SQLException { + if (type == null) { + throw new SQLException("Logical type cannot be null"); + } + switch (type) { + case BOOLEAN: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_BOOLEAN); + case TINYINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TINYINT); + case SMALLINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_SMALLINT); + case INTEGER: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER); + case BIGINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_BIGINT); + case UTINYINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UTINYINT); + case USMALLINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_USMALLINT); + case UINTEGER: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UINTEGER); + case UBIGINT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UBIGINT); + case FLOAT: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_FLOAT); + case DOUBLE: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_DOUBLE); + case VARCHAR: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_VARCHAR); + case DATE: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_DATE); + case TIMESTAMP_S: + return 
createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_S); + case TIMESTAMP_MS: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_MS); + case TIMESTAMP: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP); + case TIMESTAMP_NS: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_NS); + case TIMESTAMP_WITH_TIME_ZONE: + return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_TZ); + default: + throw new SQLException("Unsupported logical type for scalar UDF registration: " + type); + } + } + + public static DuckDBLogicalType decimal(int width, int scale) throws SQLException { + if (width < 1 || width > 38) { + throw new SQLException("DECIMAL width must be between 1 and 38, got: " + width); + } + if (scale < 0 || scale > width) { + throw new SQLException("DECIMAL scale must be between 0 and width, got: " + scale); + } + return new DuckDBLogicalType(duckdb_create_decimal_type(width, scale)); + } + + public static DuckDBLogicalType parse(String typeName) throws SQLException { + if (typeName == null) { + throw new SQLException("Logical type cannot be null"); + } + + String normalized = normalizeTypeName(typeName); + Matcher decimalMatcher = DECIMAL_PATTERN.matcher(normalized); + if (decimalMatcher.matches()) { + try { + int width = Integer.parseInt(decimalMatcher.group(2)); + String scaleText = decimalMatcher.group(3); + int scale = scaleText == null ? 
0 : Integer.parseInt(scaleText); + return decimal(width, scale); + } catch (NumberFormatException e) { + throw new SQLException("Invalid DECIMAL precision/scale: " + typeName, e); + } + } + + switch (normalized) { + case "BOOLEAN": + case "BOOL": + return of(DuckDBColumnType.BOOLEAN); + case "TINYINT": + return of(DuckDBColumnType.TINYINT); + case "SMALLINT": + return of(DuckDBColumnType.SMALLINT); + case "INTEGER": + case "INT": + return of(DuckDBColumnType.INTEGER); + case "BIGINT": + return of(DuckDBColumnType.BIGINT); + case "UTINYINT": + return of(DuckDBColumnType.UTINYINT); + case "USMALLINT": + return of(DuckDBColumnType.USMALLINT); + case "UINTEGER": + return of(DuckDBColumnType.UINTEGER); + case "UBIGINT": + return of(DuckDBColumnType.UBIGINT); + case "FLOAT": + case "REAL": + return of(DuckDBColumnType.FLOAT); + case "DOUBLE": + return of(DuckDBColumnType.DOUBLE); + case "VARCHAR": + case "TEXT": + case "STRING": + return of(DuckDBColumnType.VARCHAR); + case "DATE": + return of(DuckDBColumnType.DATE); + case "TIMESTAMP": + return of(DuckDBColumnType.TIMESTAMP); + case "TIMESTAMP_S": + return of(DuckDBColumnType.TIMESTAMP_S); + case "TIMESTAMP_MS": + return of(DuckDBColumnType.TIMESTAMP_MS); + case "TIMESTAMP_NS": + return of(DuckDBColumnType.TIMESTAMP_NS); + case "TIMESTAMPTZ": + case "TIMESTAMP WITH TIME ZONE": + return of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + default: + throw new SQLException("Unsupported scalar UDF logical type: " + typeName); + } + } + + ByteBuffer logicalTypeRef() throws SQLException { + if (logicalTypeRef == null) { + throw new SQLException("Logical type is already closed"); + } + return logicalTypeRef; + } + + static DuckDBLogicalType fromLogicalTypeRef(ByteBuffer logicalTypeRef) throws SQLException { + return new DuckDBLogicalType(logicalTypeRef); + } + + @Override + public void close() { + if (logicalTypeRef != null) { + duckdb_destroy_logical_type(logicalTypeRef); + logicalTypeRef = null; + } + } + + private static 
DuckDBLogicalType createPrimitive(DuckDBBindings.CAPIType type) throws SQLException { + return new DuckDBLogicalType(duckdb_create_logical_type(type.typeId)); + } + + private static String normalizeTypeName(String typeName) { + return typeName.trim().replaceAll("\\s+", " ").toUpperCase(Locale.ROOT); + } +} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 1fc3a8d9e..4e329cbc7 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -45,19 +45,31 @@ public static void test_bindings_logical_type() throws Exception { } public static void test_bindings_parse_logical_type() throws Exception { - ByteBuffer integerType = duckdb_jdbc_parse_logical_type(null, "INTEGER".getBytes(UTF_8)); + try (DuckDBLogicalType integerType = DuckDBLogicalType.parse("INTEGER")) { + assertNotNull(integerType); + assertEquals(DUCKDB_TYPE_INTEGER.typeId, duckdb_get_type_id(integerType.logicalTypeRef())); + } + + try (DuckDBLogicalType decimalType = DuckDBLogicalType.parse("DECIMAL(18,3)")) { + assertNotNull(decimalType); + ByteBuffer decimalRef = decimalType.logicalTypeRef(); + assertEquals(DUCKDB_TYPE_DECIMAL.typeId, duckdb_get_type_id(decimalRef)); + assertEquals(18, duckdb_decimal_width(decimalRef)); + assertEquals(3, duckdb_decimal_scale(decimalRef)); + assertEquals(DUCKDB_TYPE_BIGINT.typeId, duckdb_decimal_internal_type(decimalRef)); + } + + assertThrows(() -> { DuckDBLogicalType.parse("MOOD"); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.parse("DECIMAL(999999999999999999999999999,0)"); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.parse(null); }, SQLException.class); + } + + public static void test_bindings_parse_logical_type_native_fallback() throws Exception { + ByteBuffer integerType = duckdb_jdbc_parse_logical_type(null, "INT4".getBytes(UTF_8)); assertNotNull(integerType); assertEquals(DUCKDB_TYPE_INTEGER.typeId, 
duckdb_get_type_id(integerType)); duckdb_destroy_logical_type(integerType); - ByteBuffer decimalType = duckdb_jdbc_parse_logical_type(null, "DECIMAL(18,3)".getBytes(UTF_8)); - assertNotNull(decimalType); - assertEquals(DUCKDB_TYPE_DECIMAL.typeId, duckdb_get_type_id(decimalType)); - assertEquals(18, duckdb_decimal_width(decimalType)); - assertEquals(3, duckdb_decimal_scale(decimalType)); - assertEquals(DUCKDB_TYPE_BIGINT.typeId, duckdb_decimal_internal_type(decimalType)); - duckdb_destroy_logical_type(decimalType); - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { stmt.execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')"); @@ -445,6 +457,13 @@ public static void test_bindings_decimal_type() throws Exception { } } + public static void test_bindings_decimal_type_validation() throws Exception { + assertThrows(() -> { duckdb_create_decimal_type(0, 0); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(39, 0); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(10, -1); }, SQLException.class); + assertThrows(() -> { duckdb_create_decimal_type(10, 11); }, SQLException.class); + } + public static void test_bindings_enum_type() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index 8359d5d76..b8ffdade5 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -53,6 +53,34 @@ public static void test_register_scalar_function() throws Exception { test_register_scalar_function_integer(); } + public static void test_register_scalar_function_typed_logical_type() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + conn.registerScalarFunction("java_add_int_typed", new DuckDBLogicalType[] {intType}, intType, + (input, rowCount, out) -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setInt(i, in.getInt(i) + 1); + } + } + }); + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_typed(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + public static void test_register_scalar_function_parallel() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { @@ -175,6 +203,31 @@ public static void test_register_scalar_function_integer() throws Exception { }); } + public static void test_register_scalar_function_integer_alias_parse_fallback() throws Exception { + assertUnaryScalarFunction("java_add_int_alias", "INT4", "INT4", + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setInt(i, in.getInt(i) + 1); + } + } + }, + "SELECT java_add_int_alias(v) FROM (VALUES (1), (NULL), (41)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); + } + public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { 
assertUnaryScalarFunction("java_revalidate_int", "INTEGER", "INTEGER", (input, rowCount, out) From c83bdcf43c81f9951cb3c057eda1bdddac3c2f48 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Wed, 1 Apr 2026 11:51:51 -0300 Subject: [PATCH 3/9] Align Java scalar UDF registration with C API-first wrapper flow --- UDF.MD | 56 +++-- duckdb_java.def | 4 +- duckdb_java.exp | 4 +- duckdb_java.map | 4 +- src/jni/bindings_logical_type.cpp | 37 ---- src/jni/bindings_scalar_function.cpp | 41 ++++ src/jni/scalar_functions.cpp | 104 +++------- src/jni/scalar_functions.hpp | 4 +- src/main/java/org/duckdb/DuckDBBindings.java | 8 +- .../java/org/duckdb/DuckDBConnection.java | 60 +----- .../java/org/duckdb/DuckDBLogicalType.java | 80 -------- .../duckdb/DuckDBScalarFunctionWrapper.java | 32 +++ src/test/java/org/duckdb/TestBindings.java | 35 +--- .../java/org/duckdb/TestScalarFunctions.java | 193 +++++++++--------- 14 files changed, 251 insertions(+), 411 deletions(-) create mode 100644 src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java diff --git a/UDF.MD b/UDF.MD index 7fed83857..270a3d33e 100644 --- a/UDF.MD +++ b/UDF.MD @@ -2,17 +2,6 @@ Use `DuckDBConnection.registerScalarFunction` to register a vectorized scalar function in Java. -```java -void registerScalarFunction( - String name, - String[] parameterTypes, - String returnType, - DuckDBVectorizedScalarFunction function -) throws SQLException -``` - -You can also use the typed overload: - ```java void registerScalarFunction( String name, @@ -23,8 +12,7 @@ void registerScalarFunction( ``` Notes: -- The string overload accepts SQL types like `INTEGER`, `VARCHAR`, `TIMESTAMP`, `DECIMAL(18,2)`. -- The typed overload uses `DuckDBLogicalType` (C API style) and avoids SQL text parsing. +- The API uses typed logical types (`DuckDBLogicalType`) instead of SQL type strings. - The callback is vectorized: process `rowCount` rows from the input chunk and write one value per row into `out`. 
- `DuckDBDataChunkReader` / `DuckDBReadableVector` / `DuckDBWritableVector` are valid only during the callback. @@ -32,7 +20,8 @@ Notes: ```java try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { - conn.registerScalarFunction("java_add_one", new String[] {"INTEGER"}, "INTEGER", (input, rowCount, out) -> { + try (DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + conn.registerScalarFunction("java_add_one", new DuckDBLogicalType[] {intType}, intType, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -42,6 +31,7 @@ try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap( } } }); + } } ``` @@ -51,25 +41,29 @@ Build a label from `TIMESTAMP` + `VARCHAR` + `DOUBLE`, preserving `NULL` behavio ```java try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { - conn.registerScalarFunction( - "java_event_label", - new String[] {"TIMESTAMP", "VARCHAR", "DOUBLE"}, - "VARCHAR", - (input, rowCount, out) -> { - DuckDBReadableVector ts = input.vector(0); - DuckDBReadableVector tag = input.vector(1); - DuckDBReadableVector score = input.vector(2); + try (DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); + DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); + DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { + conn.registerScalarFunction( + "java_event_label", + new DuckDBLogicalType[] {tsType, strType, dblType}, + strType, + (input, rowCount, out) -> { + DuckDBReadableVector ts = input.vector(0); + DuckDBReadableVector tag = input.vector(1); + DuckDBReadableVector score = input.vector(2); - for (int i = 0; i < rowCount; i++) { - if (ts.isNull(i) || tag.isNull(i) || score.isNull(i)) { - out.setNull(i); - continue; + for (int i = 0; i < rowCount; i++) { + if (ts.isNull(i) 
|| tag.isNull(i) || score.isNull(i)) { + out.setNull(i); + continue; + } + String value = + ts.getLocalDateTime(i) + " | " + tag.getString(i).trim().toUpperCase() + " | " + score.getDouble(i); + out.setString(i, value); } - String value = - ts.getLocalDateTime(i) + " | " + tag.getString(i).trim().toUpperCase() + " | " + score.getDouble(i); - out.setString(i, value); } - } - ); + ); + } } ``` diff --git a/duckdb_java.def b/duckdb_java.def index bd2e5f6d8..dc9e5a4ae 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -58,10 +58,10 @@ Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type -Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale @@ -107,7 +107,7 @@ Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk -Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function duckdb_adbc_init duckdb_add_aggregate_function_to_set diff --git a/duckdb_java.exp b/duckdb_java.exp index 39d462620..cad64ce4e 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -55,11 +55,11 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name 
_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type _Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error _Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type _Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type -_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type -_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale diff --git a/duckdb_java.map b/duckdb_java.map index cdf25f812..e8179a6b3 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -57,10 +57,10 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type; Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error; Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type; - Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width; Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale; @@ -106,7 +106,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type; Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk; - 
Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function; duckdb_adbc_init; duckdb_add_aggregate_function_to_set; diff --git a/src/jni/bindings_logical_type.cpp b/src/jni/bindings_logical_type.cpp index a72750e9d..302bcb160 100644 --- a/src/jni/bindings_logical_type.cpp +++ b/src/jni/bindings_logical_type.cpp @@ -1,6 +1,4 @@ #include "bindings.hpp" -#include "duckdb/common/types.hpp" -#include "holders.hpp" #include "refs.hpp" #include "util.hpp" @@ -58,41 +56,6 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal return make_ptr_buf(env, lt); } -/* - * Class: org_duckdb_DuckDBBindings - * Method: duckdb_jdbc_parse_logical_type - * Signature: (Ljava/nio/ByteBuffer;[B)Ljava/nio/ByteBuffer; - */ -JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1parse_1logical_1type(JNIEnv *env, jclass, - jobject connection, - jbyteArray type_name) { - - if (type_name == nullptr) { - env->ThrowNew(J_SQLException, "Invalid logical type name"); - return nullptr; - } - - try { - auto sql_type_name = jbyteArray_to_string(env, type_name); - duckdb::LogicalType logical_type; - if (connection) { - auto conn = get_connection(env, connection); - if (env->ExceptionCheck()) { - return nullptr; - } - conn->context->RunFunctionInTransaction( - [&]() { logical_type = duckdb::TransformStringToLogicalType(sql_type_name, *conn->context); }); - } else { - logical_type = duckdb::TransformStringToLogicalType(sql_type_name); - } - return make_ptr_buf(env, - reinterpret_cast(new duckdb::LogicalType(std::move(logical_type)))); - } catch (const std::exception &e) { - env->ThrowNew(J_SQLException, e.what()); - return nullptr; - } -} - /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_get_type_id diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp index c59e1387a..f84a8f6f0 100644 --- a/src/jni/bindings_scalar_function.cpp 
+++ b/src/jni/bindings_scalar_function.cpp @@ -1,6 +1,8 @@ #include "bindings.hpp" +#include "functions.hpp" #include "holders.hpp" #include "refs.hpp" +#include "scalar_functions.hpp" #include "util.hpp" static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { @@ -20,6 +22,20 @@ static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env return scalar_function; } +static duckdb_function_info function_info_buf_to_function_info(JNIEnv *env, jobject function_info_buf) { + if (function_info_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function info buffer"); + return nullptr; + } + + auto function_info = reinterpret_cast(env->GetDirectBufferAddress(function_info_buf)); + if (function_info == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function info"); + return nullptr; + } + return function_info; +} + JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function(JNIEnv *env, jclass) { return make_ptr_buf(env, duckdb_create_scalar_function()); } @@ -88,3 +104,28 @@ JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1 } return static_cast(duckdb_register_scalar_function(conn, function)); } + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function( + JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { + try { + duckdb_jdbc_scalar_function_set_function(env, conn_ref_buf, scalar_function_buf, function_j); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error(JNIEnv *env, jclass, + jobject function_info_buf, + jbyteArray error) { + auto function_info = function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return; + } + if (error == 
nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function error"); + return; + } + auto error_message = jbyteArray_to_string(env, error); + duckdb_scalar_function_set_error(function_info, error_message.c_str()); +} diff --git a/src/jni/scalar_functions.cpp b/src/jni/scalar_functions.cpp index 6c3e5682e..bb77f5abf 100644 --- a/src/jni/scalar_functions.cpp +++ b/src/jni/scalar_functions.cpp @@ -2,15 +2,12 @@ extern "C" { #include "duckdb.h" } -#include "duckdb.hpp" -#include "functions.hpp" +#include "duckdb/common/exception.hpp" #include "holders.hpp" #include "refs.hpp" #include "scalar_functions.hpp" #include "util.hpp" -using namespace duckdb; - struct JNIEnvGuard { JavaVM *vm; JNIEnv *env; @@ -18,18 +15,18 @@ struct JNIEnvGuard { explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { if (!vm) { - throw InvalidInputException("JVM is not available"); + throw duckdb::InvalidInputException("JVM is not available"); } auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (get_env_status == JNI_OK) { return; } if (get_env_status != JNI_EDETACHED) { - throw InvalidInputException("Failed to get JNI environment"); + throw duckdb::InvalidInputException("Failed to get JNI environment"); } auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); if (attach_status != JNI_OK || !env) { - throw InvalidInputException("Failed to attach current thread to JVM"); + throw duckdb::InvalidInputException("Failed to attach current thread to JVM"); } detach_when_done = true; } @@ -83,14 +80,14 @@ static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env return scalar_function; } -static string consume_java_exception_message(JNIEnv *env) { +static std::string consume_java_exception_message(JNIEnv *env) { auto throwable = env->ExceptionOccurred(); if (!throwable) { return "Java exception"; } env->ExceptionClear(); - string message = "Java exception"; + std::string message = "Java 
exception"; auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); if (!env->ExceptionCheck() && msg) { message = jstring_to_string(env, msg); @@ -109,7 +106,7 @@ static string consume_java_exception_message(JNIEnv *env) { static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { if (!vm) { - throw InvalidInputException("JVM is not available"); + throw duckdb::InvalidInputException("JVM is not available"); } detach_when_done = false; @@ -118,50 +115,28 @@ static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_do return; } if (get_env_status != JNI_EDETACHED) { - throw InvalidInputException("Failed to get JNI environment"); + throw duckdb::InvalidInputException("Failed to get JNI environment"); } auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); if (attach_status != JNI_OK || !env) { - throw InvalidInputException("Failed to attach current thread to JVM"); + throw duckdb::InvalidInputException("Failed to attach current thread to JVM"); } detach_when_done = true; } -static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, DataChunk &input, - Vector &output) { - auto row_count = input.size(); - jobject input_chunk_buf = make_ptr_buf(env, &input); - jobject output_vector_buf = make_ptr_buf(env, &output); - auto input_reader = env->NewObject(J_DuckDataChunkReader, J_DuckDataChunkReader_init, input_chunk_buf, - static_cast(row_count)); - if (env->ExceptionCheck()) { - if (input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - throw InvalidInputException("Could not create DuckDBDataChunkReader: %s", consume_java_exception_message(env)); - } - - auto output_writer = env->NewObject(J_DuckWritableVector, J_DuckWritableVector_init, output_vector_buf, - static_cast(row_count)); - if (env->ExceptionCheck()) { - env->DeleteLocalRef(input_reader); - if 
(input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - throw InvalidInputException("Could not create DuckDBWritableVector: %s", consume_java_exception_message(env)); +static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, + duckdb_function_info info, duckdb_data_chunk input, + duckdb_vector output) { + auto row_count = duckdb_data_chunk_get_size(input); + jobject function_info_buf = make_ptr_buf(env, info); + jobject input_chunk_buf = make_ptr_buf(env, input); + jobject output_vector_buf = make_ptr_buf(env, output); + env->CallVoidMethod(state.callback, state.apply_method, function_info_buf, input_chunk_buf, + static_cast(row_count), output_vector_buf); + if (function_info_buf) { + env->DeleteLocalRef(function_info_buf); } - - env->CallVoidMethod(state.callback, state.apply_method, input_reader, static_cast(row_count), output_writer); - - env->DeleteLocalRef(output_writer); - env->DeleteLocalRef(input_reader); if (input_chunk_buf) { env->DeleteLocalRef(input_chunk_buf); } @@ -170,7 +145,8 @@ static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunct } if (env->ExceptionCheck()) { - throw InvalidInputException("Java scalar function threw exception: %s", consume_java_exception_message(env)); + throw duckdb::InvalidInputException("Java scalar function wrapper threw exception: %s", + consume_java_exception_message(env)); } } @@ -179,45 +155,45 @@ static void init_java_scalar_function_capi(duckdb_init_info info); static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, const char *signature, - const char *error_message) { + const char *method_name, const char *error_message) { auto callback_class = env->GetObjectClass(function_j); - auto apply_method = env->GetMethodID(callback_class, 
"apply", signature); + auto apply_method = env->GetMethodID(callback_class, method_name, signature); env->DeleteLocalRef(callback_class); if (!apply_method || env->ExceptionCheck()) { consume_java_exception_message(env); - throw InvalidInputException("%s", error_message); + throw duckdb::InvalidInputException("%s", error_message); } return apply_method; } -void duckdb_jdbc_install_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, - jobject function_j) { +void duckdb_jdbc_scalar_function_set_function(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, + jobject function_j) { auto connection = get_connection(env, conn_ref_buf); if (!connection) { - throw InvalidInputException("Invalid connection"); + throw duckdb::InvalidInputException("Invalid connection"); } auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); if (env->ExceptionCheck()) { return; } if (!function_j) { - throw InvalidInputException("Invalid scalar function callback"); + throw duckdb::InvalidInputException("Invalid scalar function callback"); } JavaVM *vm = nullptr; if (env->GetJavaVM(&vm) != JNI_OK || !vm) { - throw InvalidInputException("Failed to get JVM reference"); + throw duckdb::InvalidInputException("Failed to get JVM reference"); } auto callback_ref = env->NewGlobalRef(function_j); if (!callback_ref) { - throw InvalidInputException("Could not create global reference for scalar function callback"); + throw duckdb::InvalidInputException("Could not create global reference for scalar function callback"); } try { auto apply_method = get_scalar_callback_method( - env, function_j, "(Lorg/duckdb/DuckDBDataChunkReader;ILorg/duckdb/DuckDBWritableVector;)V", - "Could not find apply(DuckDBDataChunkReader, int, DuckDBWritableVector) on scalar function callback"); + env, function_j, "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;ILjava/nio/ByteBuffer;)V", "execute", + "Could not find execute(ByteBuffer, ByteBuffer, int, 
ByteBuffer) on scalar function callback"); auto state = new JavaScalarFunctionState(vm, callback_ref, apply_method); duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); @@ -281,20 +257,8 @@ static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_ } try { - auto &input_chunk = *reinterpret_cast(input); - auto &output_vector = *reinterpret_cast(output); - execute_java_vectorized_scalar_function(local_state->env, *state, input_chunk, output_vector); + execute_java_vectorized_scalar_function(local_state->env, *state, info, input, output); } catch (const std::exception &e) { duckdb_scalar_function_set_error(info, e.what()); } } - -extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1scalar_1function_1set_1callback( - JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { - try { - duckdb_jdbc_install_scalar_function_callback(env, conn_ref_buf, scalar_function_buf, function_j); - } catch (const std::exception &e) { - duckdb::ErrorData error(e); - ThrowJNI(env, error.Message().c_str()); - } -} diff --git a/src/jni/scalar_functions.hpp b/src/jni/scalar_functions.hpp index 0ff981a6d..d5bea4265 100644 --- a/src/jni/scalar_functions.hpp +++ b/src/jni/scalar_functions.hpp @@ -2,5 +2,5 @@ #include "bindings.hpp" -void duckdb_jdbc_install_scalar_function_callback(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, - jobject function_j); +void duckdb_jdbc_scalar_function_set_function(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, + jobject function_j); diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 1f1a80e6b..56a4c5689 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -31,8 +31,10 @@ public class 
DuckDBBindings { static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalarFunction); - static native void duckdb_jdbc_scalar_function_set_callback(ByteBuffer connection, ByteBuffer scalarFunction, - DuckDBVectorizedScalarFunction function); + static native void duckdb_scalar_function_set_function(ByteBuffer connection, ByteBuffer scalarFunction, + Object function); + + static native void duckdb_scalar_function_set_error(ByteBuffer functionInfo, byte[] error); static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer vectorData, ByteBuffer validity, long rowCount, long row); @@ -43,8 +45,6 @@ static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer vectorData, Byt static native ByteBuffer duckdb_create_decimal_type(int width, int scale); - static native ByteBuffer duckdb_jdbc_parse_logical_type(ByteBuffer connection, byte[] type_name); - static native int duckdb_get_type_id(ByteBuffer logical_type); static native int duckdb_decimal_width(ByteBuffer logical_type); diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index ae4f36242..8c010dea6 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -500,38 +500,6 @@ public void registerArrowStream(String name, Object arrow_array_stream) { } } - public void registerScalarFunction(String name, String[] parameterTypes, String returnType, - DuckDBVectorizedScalarFunction function) throws SQLException { - checkOpen(); - if (parameterTypes == null) { - throw new SQLException("Parameter types cannot be null"); - } - DuckDBLogicalType[] parsedParameterTypes = new DuckDBLogicalType[parameterTypes.length]; - DuckDBLogicalType parsedReturnType = null; - try { - if (name == null || name.trim().isEmpty()) { - throw new SQLException("Function name cannot be null or empty"); - } - for (int i = 0; i < parameterTypes.length; i++) { - String parameterType = 
parameterTypes[i]; - if (parameterType == null || parameterType.trim().isEmpty()) { - throw new SQLException("Parameter type at index " + i + " cannot be null or empty"); - } - parsedParameterTypes[i] = parseStringLogicalType(parameterType); - } - if (returnType == null || returnType.trim().isEmpty()) { - throw new SQLException("Return type cannot be null or empty"); - } - parsedReturnType = parseStringLogicalType(returnType); - registerScalarFunction(name, parsedParameterTypes, parsedReturnType, function); - } finally { - closeLogicalType(parsedReturnType); - for (DuckDBLogicalType parameterType : parsedParameterTypes) { - closeLogicalType(parameterType); - } - } - } - public void registerScalarFunction(String name, DuckDBLogicalType[] parameterTypes, DuckDBLogicalType returnType, DuckDBVectorizedScalarFunction function) throws SQLException { checkOpen(); @@ -565,7 +533,8 @@ public void registerScalarFunction(String name, DuckDBLogicalType[] parameterTyp } DuckDBBindings.duckdb_scalar_function_set_return_type(scalarFunction, returnType.logicalTypeRef()); - DuckDBBindings.duckdb_jdbc_scalar_function_set_callback(connRef, scalarFunction, function); + DuckDBBindings.duckdb_scalar_function_set_function(connRef, scalarFunction, + new DuckDBScalarFunctionWrapper(function)); if (DuckDBBindings.duckdb_register_scalar_function(connRef, scalarFunction) != 0) { throw new SQLException("Failed to register scalar function '" + name + "'"); @@ -597,31 +566,6 @@ public String getSessionInitSQL() throws SQLException { return sessionInitSQL; } - private static void closeLogicalType(DuckDBLogicalType logicalType) { - if (logicalType != null) { - logicalType.close(); - } - } - - private DuckDBLogicalType parseStringLogicalType(String typeName) throws SQLException { - try { - return DuckDBLogicalType.parse(typeName); - } catch (SQLException javaParseError) { - connRefLock.lock(); - try { - checkOpen(); - ByteBuffer parsedType = - DuckDBBindings.duckdb_jdbc_parse_logical_type(connRef, 
typeName.getBytes(UTF_8)); - return DuckDBLogicalType.fromLogicalTypeRef(parsedType); - } catch (SQLException nativeParseError) { - nativeParseError.addSuppressed(javaParseError); - throw nativeParseError; - } finally { - connRefLock.unlock(); - } - } - } - void checkOpen() throws SQLException { if (isClosed()) { throw new SQLException("Connection was closed"); diff --git a/src/main/java/org/duckdb/DuckDBLogicalType.java b/src/main/java/org/duckdb/DuckDBLogicalType.java index 2615c79ef..1578e6ba2 100644 --- a/src/main/java/org/duckdb/DuckDBLogicalType.java +++ b/src/main/java/org/duckdb/DuckDBLogicalType.java @@ -4,14 +4,8 @@ import java.nio.ByteBuffer; import java.sql.SQLException; -import java.util.Locale; -import java.util.regex.Matcher; -import java.util.regex.Pattern; public final class DuckDBLogicalType implements AutoCloseable { - private static final Pattern DECIMAL_PATTERN = - Pattern.compile("^(DECIMAL|NUMERIC)\\s*\\(\\s*(\\d+)\\s*(?:,\\s*(\\d+)\\s*)?\\)$", Pattern.CASE_INSENSITIVE); - private ByteBuffer logicalTypeRef; private DuckDBLogicalType(ByteBuffer logicalTypeRef) throws SQLException { @@ -77,72 +71,6 @@ public static DuckDBLogicalType decimal(int width, int scale) throws SQLExceptio return new DuckDBLogicalType(duckdb_create_decimal_type(width, scale)); } - public static DuckDBLogicalType parse(String typeName) throws SQLException { - if (typeName == null) { - throw new SQLException("Logical type cannot be null"); - } - - String normalized = normalizeTypeName(typeName); - Matcher decimalMatcher = DECIMAL_PATTERN.matcher(normalized); - if (decimalMatcher.matches()) { - try { - int width = Integer.parseInt(decimalMatcher.group(2)); - String scaleText = decimalMatcher.group(3); - int scale = scaleText == null ? 
0 : Integer.parseInt(scaleText); - return decimal(width, scale); - } catch (NumberFormatException e) { - throw new SQLException("Invalid DECIMAL precision/scale: " + typeName, e); - } - } - - switch (normalized) { - case "BOOLEAN": - case "BOOL": - return of(DuckDBColumnType.BOOLEAN); - case "TINYINT": - return of(DuckDBColumnType.TINYINT); - case "SMALLINT": - return of(DuckDBColumnType.SMALLINT); - case "INTEGER": - case "INT": - return of(DuckDBColumnType.INTEGER); - case "BIGINT": - return of(DuckDBColumnType.BIGINT); - case "UTINYINT": - return of(DuckDBColumnType.UTINYINT); - case "USMALLINT": - return of(DuckDBColumnType.USMALLINT); - case "UINTEGER": - return of(DuckDBColumnType.UINTEGER); - case "UBIGINT": - return of(DuckDBColumnType.UBIGINT); - case "FLOAT": - case "REAL": - return of(DuckDBColumnType.FLOAT); - case "DOUBLE": - return of(DuckDBColumnType.DOUBLE); - case "VARCHAR": - case "TEXT": - case "STRING": - return of(DuckDBColumnType.VARCHAR); - case "DATE": - return of(DuckDBColumnType.DATE); - case "TIMESTAMP": - return of(DuckDBColumnType.TIMESTAMP); - case "TIMESTAMP_S": - return of(DuckDBColumnType.TIMESTAMP_S); - case "TIMESTAMP_MS": - return of(DuckDBColumnType.TIMESTAMP_MS); - case "TIMESTAMP_NS": - return of(DuckDBColumnType.TIMESTAMP_NS); - case "TIMESTAMPTZ": - case "TIMESTAMP WITH TIME ZONE": - return of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); - default: - throw new SQLException("Unsupported scalar UDF logical type: " + typeName); - } - } - ByteBuffer logicalTypeRef() throws SQLException { if (logicalTypeRef == null) { throw new SQLException("Logical type is already closed"); @@ -150,10 +78,6 @@ ByteBuffer logicalTypeRef() throws SQLException { return logicalTypeRef; } - static DuckDBLogicalType fromLogicalTypeRef(ByteBuffer logicalTypeRef) throws SQLException { - return new DuckDBLogicalType(logicalTypeRef); - } - @Override public void close() { if (logicalTypeRef != null) { @@ -165,8 +89,4 @@ public void close() { private static 
DuckDBLogicalType createPrimitive(DuckDBBindings.CAPIType type) throws SQLException { return new DuckDBLogicalType(duckdb_create_logical_type(type.typeId)); } - - private static String normalizeTypeName(String typeName) { - return typeName.trim().replaceAll("\\s+", " ").toUpperCase(Locale.ROOT); - } } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java new file mode 100644 index 000000000..736a91b3b --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -0,0 +1,32 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.nio.ByteBuffer; + +final class DuckDBScalarFunctionWrapper { + private final DuckDBVectorizedScalarFunction function; + + DuckDBScalarFunctionWrapper(DuckDBVectorizedScalarFunction function) { + this.function = function; + } + + public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, int rowCount, ByteBuffer outputVector) { + try { + DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk, rowCount); + DuckDBWritableVector outputWriter = new DuckDBWritableVector(outputVector, rowCount); + function.apply(inputReader, rowCount, outputWriter); + } catch (Throwable throwable) { + reportError(functionInfo, throwable); + } + } + + private static void reportError(ByteBuffer functionInfo, Throwable throwable) { + String message = throwable.getMessage(); + String className = throwable.getClass().getName(); + String formatted = + message == null || message.isEmpty() ? 
className : String.format("%s: %s", className, message); + String error = "Java scalar function threw exception: " + formatted; + DuckDBBindings.duckdb_scalar_function_set_error(functionInfo, error.getBytes(UTF_8)); + } +} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 4e329cbc7..ca8b2431f 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -45,12 +45,12 @@ public static void test_bindings_logical_type() throws Exception { } public static void test_bindings_parse_logical_type() throws Exception { - try (DuckDBLogicalType integerType = DuckDBLogicalType.parse("INTEGER")) { + try (DuckDBLogicalType integerType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { assertNotNull(integerType); assertEquals(DUCKDB_TYPE_INTEGER.typeId, duckdb_get_type_id(integerType.logicalTypeRef())); } - try (DuckDBLogicalType decimalType = DuckDBLogicalType.parse("DECIMAL(18,3)")) { + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(18, 3)) { assertNotNull(decimalType); ByteBuffer decimalRef = decimalType.logicalTypeRef(); assertEquals(DUCKDB_TYPE_DECIMAL.typeId, duckdb_get_type_id(decimalRef)); @@ -59,34 +59,9 @@ public static void test_bindings_parse_logical_type() throws Exception { assertEquals(DUCKDB_TYPE_BIGINT.typeId, duckdb_decimal_internal_type(decimalRef)); } - assertThrows(() -> { DuckDBLogicalType.parse("MOOD"); }, SQLException.class); - assertThrows(() -> { DuckDBLogicalType.parse("DECIMAL(999999999999999999999999999,0)"); }, SQLException.class); - assertThrows(() -> { DuckDBLogicalType.parse(null); }, SQLException.class); - } - - public static void test_bindings_parse_logical_type_native_fallback() throws Exception { - ByteBuffer integerType = duckdb_jdbc_parse_logical_type(null, "INT4".getBytes(UTF_8)); - assertNotNull(integerType); - assertEquals(DUCKDB_TYPE_INTEGER.typeId, duckdb_get_type_id(integerType)); - 
duckdb_destroy_logical_type(integerType); - - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - stmt.execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')"); - - ByteBuffer enumType = duckdb_jdbc_parse_logical_type(conn.connRef, "mood".getBytes(UTF_8)); - assertNotNull(enumType); - assertEquals(DUCKDB_TYPE_ENUM.typeId, duckdb_get_type_id(enumType)); - assertEquals(3L, duckdb_enum_dictionary_size(enumType)); - assertEquals("sad".getBytes(UTF_8), duckdb_enum_dictionary_value(enumType, 0)); - duckdb_destroy_logical_type(enumType); - - assertThrows(() -> { - duckdb_jdbc_parse_logical_type(conn.connRef, "missing_type".getBytes(UTF_8)); - }, SQLException.class); - } - - assertThrows(() -> { duckdb_jdbc_parse_logical_type(null, null); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.of(null); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.decimal(39, 0); }, SQLException.class); + assertThrows(() -> { DuckDBLogicalType.decimal(10, 11); }, SQLException.class); } public static void test_bindings_vector_create() throws Exception { diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index b8ffdade5..bd35d5447 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -83,9 +83,10 @@ public static void test_register_scalar_function_typed_logical_type() throws Exc public static void test_register_scalar_function_parallel() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { + Statement stmt = conn.createStatement(); + DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { stmt.execute("PRAGMA threads=4"); - conn.registerScalarFunction("java_add_one_bigint", new String[] {"BIGINT"}, "BIGINT", 
+ conn.registerScalarFunction("java_add_one_bigint", new DuckDBLogicalType[] {bigintType}, bigintType, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { @@ -102,8 +103,22 @@ public static void test_register_scalar_function_parallel() throws Exception { } } + public static void test_register_scalar_function_exception_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + conn.registerScalarFunction("java_throws_exception", new DuckDBLogicalType[] {intType}, intType, + (input, rowCount, out) -> { throw new IllegalStateException("boom"); }); + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_throws_exception(1)"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("IllegalStateException")); + assertTrue(message.contains("boom")); + } + } + public static void test_register_scalar_function_boolean() throws Exception { - assertUnaryScalarFunction("java_not_bool", "BOOLEAN", "BOOLEAN", + assertUnaryScalarFunction("java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -128,7 +143,7 @@ public static void test_register_scalar_function_boolean() throws Exception { } public static void test_register_scalar_function_tinyint() throws Exception { - assertUnaryScalarFunction("java_add_tinyint", "TINYINT", "TINYINT", + assertUnaryScalarFunction("java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -154,7 +169,7 @@ public static void test_register_scalar_function_tinyint() throws Exception { public static void test_register_scalar_function_smallint() 
throws Exception { assertUnaryScalarFunction( - "java_add_smallint", "SMALLINT", "SMALLINT", + "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -179,7 +194,7 @@ public static void test_register_scalar_function_smallint() throws Exception { } public static void test_register_scalar_function_integer() throws Exception { - assertUnaryScalarFunction("java_add_int", "INTEGER", "INTEGER", + assertUnaryScalarFunction("java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -203,33 +218,8 @@ public static void test_register_scalar_function_integer() throws Exception { }); } - public static void test_register_scalar_function_integer_alias_parse_fallback() throws Exception { - assertUnaryScalarFunction("java_add_int_alias", "INT4", "INT4", - (input, rowCount, out) - -> { - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setInt(i, in.getInt(i) + 1); - } - } - }, - "SELECT java_add_int_alias(v) FROM (VALUES (1), (NULL), (41)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 2); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 42); - assertFalse(rs.next()); - }); - } - public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { - assertUnaryScalarFunction("java_revalidate_int", "INTEGER", "INTEGER", + assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -254,7 +244,7 @@ public static void test_register_scalar_function_integer_revalidates_after_null( } public static void test_register_scalar_function_bigint() throws Exception { - 
assertUnaryScalarFunction("java_add_bigint", "BIGINT", "BIGINT", + assertUnaryScalarFunction("java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -280,7 +270,7 @@ public static void test_register_scalar_function_bigint() throws Exception { public static void test_register_scalar_function_utinyint() throws Exception { assertUnaryScalarFunction( - "java_add_utinyint", "UTINYINT", "UTINYINT", + "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -306,7 +296,7 @@ public static void test_register_scalar_function_utinyint() throws Exception { public static void test_register_scalar_function_usmallint() throws Exception { assertUnaryScalarFunction( - "java_add_usmallint", "USMALLINT", "USMALLINT", + "java_add_usmallint", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -332,7 +322,7 @@ public static void test_register_scalar_function_usmallint() throws Exception { public static void test_register_scalar_function_uinteger() throws Exception { assertUnaryScalarFunction( - "java_add_uinteger", "UINTEGER", "UINTEGER", + "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -357,7 +347,7 @@ public static void test_register_scalar_function_uinteger() throws Exception { } public static void test_register_scalar_function_ubigint() throws Exception { - assertUnaryScalarFunction("java_add_ubigint", "UBIGINT", "UBIGINT", + assertUnaryScalarFunction("java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -385,7 +375,7 @@ public static void test_register_scalar_function_ubigint() throws Exception { } public static void 
test_register_scalar_function_float() throws Exception { - assertUnaryScalarFunction("java_add_float", "FLOAT", "FLOAT", + assertUnaryScalarFunction("java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -410,7 +400,7 @@ public static void test_register_scalar_function_float() throws Exception { } public static void test_register_scalar_function_double() throws Exception { - assertUnaryScalarFunction("java_add_double", "DOUBLE", "DOUBLE", + assertUnaryScalarFunction("java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -435,40 +425,43 @@ public static void test_register_scalar_function_double() throws Exception { } public static void test_register_scalar_function_decimal() throws Exception { - assertUnaryScalarFunction("java_add_decimal", "DECIMAL(38,10)", "DECIMAL(38,10)", - (input, rowCount, out) - -> { - DuckDBReadableVector in = input.vector(0); - BigDecimal increment = new BigDecimal("0.0000000001"); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setBigDecimal(i, in.getBigDecimal(i).add(increment)); + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { + assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, + (input, rowCount, out) + -> { + DuckDBReadableVector in = input.vector(0); + BigDecimal increment = new BigDecimal("0.0000000001"); + for (int i = 0; i < rowCount; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.setBigDecimal(i, in.getBigDecimal(i).add(increment)); + } } - } - }, - "SELECT java_add_decimal(v) FROM (VALUES " - + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " - + "(NULL), " - + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), - new 
BigDecimal("12345678901234567890.1234567891")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); - assertFalse(rs.next()); - }); + }, + "SELECT java_add_decimal(v) FROM (VALUES " + + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " + + "(NULL), " + + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), + new BigDecimal("12345678901234567890.1234567891")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); + assertFalse(rs.next()); + }); + } } public static void test_register_scalar_function_decimal_precision_overflow() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - conn.registerScalarFunction("java_decimal_precision_overflow", new String[] {"DECIMAL(10,2)"}, - "DECIMAL(10,2)", (input, rowCount, out) -> { + Statement stmt = conn.createStatement(); + DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + conn.registerScalarFunction("java_decimal_precision_overflow", new DuckDBLogicalType[] {decimalType}, + decimalType, (input, rowCount, out) -> { for (int i = 0; i < rowCount; i++) { out.setBigDecimal(i, new BigDecimal("12345678901.23")); } @@ -483,9 +476,10 @@ public static void test_register_scalar_function_decimal_precision_overflow() th public static void test_register_scalar_function_decimal_scale_overflow() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - conn.registerScalarFunction("java_decimal_scale_overflow", new String[] {"DECIMAL(10,2)"}, "DECIMAL(10,2)", - (input, rowCount, out) -> { + Statement stmt = 
conn.createStatement(); + DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + conn.registerScalarFunction("java_decimal_scale_overflow", new DuckDBLogicalType[] {decimalType}, + decimalType, (input, rowCount, out) -> { for (int i = 0; i < rowCount; i++) { out.setBigDecimal(i, new BigDecimal("1.234")); } @@ -500,7 +494,7 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws public static void test_register_scalar_function_date() throws Exception { assertUnaryScalarFunction( - "java_add_date", "DATE", "DATE", + "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -526,7 +520,7 @@ public static void test_register_scalar_function_date() throws Exception { } public static void test_register_scalar_function_date_from_java_util_date() throws Exception { - assertUnaryScalarFunction("java_date_from_util_date", "DATE", "DATE", + assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -550,7 +544,7 @@ public static void test_register_scalar_function_date_from_java_util_date() thro } public static void test_register_scalar_function_timestamp() throws Exception { - assertUnaryScalarFunction("java_add_timestamp", "TIMESTAMP", "TIMESTAMP", + assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -582,7 +576,7 @@ public static void test_register_scalar_function_timestamp() throws Exception { public static void test_register_scalar_function_timestamp_s() throws Exception { assertUnaryScalarFunction( - "java_add_timestamp_s", "TIMESTAMP_S", "TIMESTAMP_S", + "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ 
-605,7 +599,8 @@ public static void test_register_scalar_function_timestamp_s() throws Exception } public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { - assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", "TIMESTAMP", "TIMESTAMP_S", + assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, + DuckDBColumnType.TIMESTAMP_S, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -627,7 +622,7 @@ public static void test_register_scalar_function_timestamp_s_pre_epoch() throws } public static void test_register_scalar_function_timestamp_ms() throws Exception { - assertUnaryScalarFunction("java_add_timestamp_ms", "TIMESTAMP_MS", "TIMESTAMP_MS", + assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -652,7 +647,8 @@ public static void test_register_scalar_function_timestamp_ms() throws Exception } public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { - assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", "TIMESTAMP", "TIMESTAMP_MS", + assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, + DuckDBColumnType.TIMESTAMP_MS, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -675,7 +671,7 @@ public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws } public static void test_register_scalar_function_timestamp_ns() throws Exception { - assertUnaryScalarFunction("java_add_timestamp_ns", "TIMESTAMP_NS", "TIMESTAMP_NS", + assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -701,7 +697,8 @@ public static void test_register_scalar_function_timestamp_ns() throws Exception public static void 
test_register_scalar_function_timestamptz() throws Exception { assertUnaryScalarFunction( - "java_add_timestamptz", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH TIME ZONE", + "java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -727,7 +724,8 @@ public static void test_register_scalar_function_timestamptz() throws Exception public static void test_register_scalar_function_timestamptz_set_timestamp() throws Exception { assertUnaryScalarFunction( - "java_copy_timestamptz_with_timestamp", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH TIME ZONE", + "java_copy_timestamptz_with_timestamp", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -751,7 +749,7 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { assertUnaryScalarFunction( - "java_timestamp_from_util_date", "TIMESTAMP", "TIMESTAMP", + "java_timestamp_from_util_date", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -780,7 +778,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date() public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { assertUnaryScalarFunction( - "java_timestamp_from_util_ts", "TIMESTAMP", "TIMESTAMP", + "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -804,7 +802,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws 
Exception { assertUnaryScalarFunction( - "java_timestamp_from_util_sql_date", "DATE", "TIMESTAMP", + "java_timestamp_from_util_sql_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -827,7 +825,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws Exception { assertUnaryScalarFunction( - "java_timestamp_from_util_sql_time", "TIMESTAMP", "TIMESTAMP", + "java_timestamp_from_util_sql_time", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { for (int i = 0; i < rowCount; i++) { @@ -845,7 +843,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { assertUnaryScalarFunction( - "java_timestamp_from_local_date", "DATE", "TIMESTAMP", + "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -869,7 +867,7 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr } public static void test_register_scalar_function_varchar() throws Exception { - assertUnaryScalarFunction("java_suffix_varchar", "VARCHAR", "VARCHAR", + assertUnaryScalarFunction("java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -895,7 +893,7 @@ public static void test_register_scalar_function_varchar() throws Exception { } public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { - assertUnaryScalarFunction("java_echo_varchar_nullable", "VARCHAR", "VARCHAR", + assertUnaryScalarFunction("java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, (input, 
rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -917,7 +915,7 @@ public static void test_register_scalar_function_varchar_get_string_handles_null } public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { - assertUnaryScalarFunction("java_revalidate_varchar", "VARCHAR", "VARCHAR", + assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, (input, rowCount, out) -> { DuckDBReadableVector in = input.vector(0); @@ -941,12 +939,21 @@ public static void test_register_scalar_function_varchar_revalidates_after_null( }); } - private static void assertUnaryScalarFunction(String functionName, String parameterType, String returnType, - DuckDBVectorizedScalarFunction function, String query, - ResultSetVerifier verifier) throws Exception { + private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, DuckDBVectorizedScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); + DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { + assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); + } + } + + private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, DuckDBVectorizedScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - conn.registerScalarFunction(functionName, new String[] {parameterType}, returnType, function); + conn.registerScalarFunction(functionName, new DuckDBLogicalType[] {parameterType}, returnType, function); try (ResultSet rs = stmt.executeQuery(query)) { 
verifier.verify(rs); } From 925981bdd88bd6f4027ff91804c02c0609508fdc Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Wed, 1 Apr 2026 18:28:00 -0300 Subject: [PATCH 4/9] Refine Java scalar UDF callback API and migrate tests Replace the vectorized callback interface with DuckDBScalarFunction, remove explicit rowCount from callback signatures, and derive row count from input chunks in Java. Update JNI callback invocation signature accordingly, align connection registration and wrapper plumbing, remove legacy RowCountScalarFunction test adapters, and migrate scalar UDF tests/docs to the new (input, out) callback format. --- UDF.MD | 12 +- src/jni/refs.cpp | 2 +- src/jni/scalar_functions.cpp | 15 +-- .../java/org/duckdb/DuckDBConnection.java | 2 +- .../org/duckdb/DuckDBDataChunkReader.java | 4 +- ...unction.java => DuckDBScalarFunction.java} | 5 +- .../duckdb/DuckDBScalarFunctionWrapper.java | 12 +- .../java/org/duckdb/DuckDBVectorTypeInfo.java | 2 +- .../java/org/duckdb/TestScalarFunctions.java | 111 ++++++++++++------ 9 files changed, 99 insertions(+), 66 deletions(-) rename src/main/java/org/duckdb/{DuckDBVectorizedScalarFunction.java => DuckDBScalarFunction.java} (68%) diff --git a/UDF.MD b/UDF.MD index 270a3d33e..1d641cc2c 100644 --- a/UDF.MD +++ b/UDF.MD @@ -1,19 +1,19 @@ # Java Scalar Functions (UDF) -Use `DuckDBConnection.registerScalarFunction` to register a vectorized scalar function in Java. +Use `DuckDBConnection.registerScalarFunction` to register a scalar function in Java. ```java void registerScalarFunction( String name, DuckDBLogicalType[] parameterTypes, DuckDBLogicalType returnType, - DuckDBVectorizedScalarFunction function + DuckDBScalarFunction function ) throws SQLException ``` Notes: - The API uses typed logical types (`DuckDBLogicalType`) instead of SQL type strings. -- The callback is vectorized: process `rowCount` rows from the input chunk and write one value per row into `out`. 
+- The callback processes all rows from the current input chunk (`input.rowCount()`) and writes one value per row into `out`. - `DuckDBDataChunkReader` / `DuckDBReadableVector` / `DuckDBWritableVector` are valid only during the callback. ## Simple example @@ -21,8 +21,9 @@ Notes: ```java try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { try (DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - conn.registerScalarFunction("java_add_one", new DuckDBLogicalType[] {intType}, intType, (input, rowCount, out) -> { + conn.registerScalarFunction("java_add_one", new DuckDBLogicalType[] {intType}, intType, (input, out) -> { DuckDBReadableVector in = input.vector(0); + int rowCount = input.rowCount(); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { out.setNull(i); @@ -48,10 +49,11 @@ try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap( "java_event_label", new DuckDBLogicalType[] {tsType, strType, dblType}, strType, - (input, rowCount, out) -> { + (input, out) -> { DuckDBReadableVector ts = input.vector(0); DuckDBReadableVector tag = input.vector(1); DuckDBReadableVector score = input.vector(2); + int rowCount = input.rowCount(); for (int i = 0; i < rowCount; i++) { if (ts.isNull(i) || tag.isNull(i) || score.isNull(i)) { diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index d76cfbd80..8e7e7799e 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -294,7 +294,7 @@ void create_refs(JNIEnv *env) { J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); J_DuckDataChunkReader = make_class_ref(env, "org/duckdb/DuckDBDataChunkReader"); - J_DuckDataChunkReader_init = get_method_id(env, J_DuckDataChunkReader, "", "(Ljava/nio/ByteBuffer;I)V"); + J_DuckDataChunkReader_init = get_method_id(env, J_DuckDataChunkReader, "", "(Ljava/nio/ByteBuffer;)V"); J_DuckWritableVector = make_class_ref(env, "org/duckdb/DuckDBWritableVector"); 
J_DuckWritableVector_init = get_method_id(env, J_DuckWritableVector, "", "(Ljava/nio/ByteBuffer;I)V"); diff --git a/src/jni/scalar_functions.cpp b/src/jni/scalar_functions.cpp index bb77f5abf..563ecd5c0 100644 --- a/src/jni/scalar_functions.cpp +++ b/src/jni/scalar_functions.cpp @@ -125,15 +125,12 @@ static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_do detach_when_done = true; } -static void execute_java_vectorized_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, - duckdb_function_info info, duckdb_data_chunk input, - duckdb_vector output) { - auto row_count = duckdb_data_chunk_get_size(input); +static void execute_java_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, duckdb_function_info info, + duckdb_data_chunk input, duckdb_vector output) { jobject function_info_buf = make_ptr_buf(env, info); jobject input_chunk_buf = make_ptr_buf(env, input); jobject output_vector_buf = make_ptr_buf(env, output); - env->CallVoidMethod(state.callback, state.apply_method, function_info_buf, input_chunk_buf, - static_cast(row_count), output_vector_buf); + env->CallVoidMethod(state.callback, state.apply_method, function_info_buf, input_chunk_buf, output_vector_buf); if (function_info_buf) { env->DeleteLocalRef(function_info_buf); } @@ -192,8 +189,8 @@ void duckdb_jdbc_scalar_function_set_function(JNIEnv *env, jobject conn_ref_buf, try { auto apply_method = get_scalar_callback_method( - env, function_j, "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;ILjava/nio/ByteBuffer;)V", "execute", - "Could not find execute(ByteBuffer, ByteBuffer, int, ByteBuffer) on scalar function callback"); + env, function_j, "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V", "execute", + "Could not find execute(ByteBuffer, ByteBuffer, ByteBuffer) on scalar function callback"); auto state = new JavaScalarFunctionState(vm, callback_ref, apply_method); duckdb_scalar_function_set_extra_info(scalar_function, state, 
destroy_java_scalar_function_state); duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); @@ -257,7 +254,7 @@ static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_ } try { - execute_java_vectorized_scalar_function(local_state->env, *state, info, input, output); + execute_java_scalar_function(local_state->env, *state, info, input, output); } catch (const std::exception &e) { duckdb_scalar_function_set_error(info, e.what()); } diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 8c010dea6..b27a39597 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -501,7 +501,7 @@ public void registerArrowStream(String name, Object arrow_array_stream) { } public void registerScalarFunction(String name, DuckDBLogicalType[] parameterTypes, DuckDBLogicalType returnType, - DuckDBVectorizedScalarFunction function) throws SQLException { + DuckDBScalarFunction function) throws SQLException { checkOpen(); connRefLock.lock(); ByteBuffer scalarFunction = null; diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java index c997f8842..af628f5a3 100644 --- a/src/main/java/org/duckdb/DuckDBDataChunkReader.java +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -11,12 +11,12 @@ public final class DuckDBDataChunkReader { private final int columnCount; private final DuckDBReadableVector[] vectors; - DuckDBDataChunkReader(ByteBuffer chunkRef, int rowCount) throws SQLException { + DuckDBDataChunkReader(ByteBuffer chunkRef) throws SQLException { if (chunkRef == null) { throw new SQLException("Invalid data chunk reference"); } this.chunkRef = chunkRef; - this.rowCount = rowCount; + this.rowCount = (int) duckdb_data_chunk_get_size(chunkRef); this.columnCount = (int) duckdb_data_chunk_get_column_count(chunkRef); this.vectors = new 
DuckDBReadableVector[columnCount]; } diff --git a/src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java b/src/main/java/org/duckdb/DuckDBScalarFunction.java similarity index 68% rename from src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java rename to src/main/java/org/duckdb/DuckDBScalarFunction.java index 66649f694..9ef0e6fca 100644 --- a/src/main/java/org/duckdb/DuckDBVectorizedScalarFunction.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunction.java @@ -1,16 +1,15 @@ package org.duckdb; @FunctionalInterface -public interface DuckDBVectorizedScalarFunction { +public interface DuckDBScalarFunction { /** * Processes a full input chunk and writes one output value per row directly into the DuckDB output vector. * *

The input and output wrappers are valid only for the duration of the callback and must not be retained. * * @param input input vectors for the current chunk - * @param rowCount number of rows in the current chunk * @param out output vector for the current chunk * @throws Exception when function execution fails */ - void apply(DuckDBDataChunkReader input, int rowCount, DuckDBWritableVector out) throws Exception; + void apply(DuckDBDataChunkReader input, DuckDBWritableVector out) throws Exception; } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java index 736a91b3b..9390bd135 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -5,17 +5,17 @@ import java.nio.ByteBuffer; final class DuckDBScalarFunctionWrapper { - private final DuckDBVectorizedScalarFunction function; + private final DuckDBScalarFunction function; - DuckDBScalarFunctionWrapper(DuckDBVectorizedScalarFunction function) { + DuckDBScalarFunctionWrapper(DuckDBScalarFunction function) { this.function = function; } - public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, int rowCount, ByteBuffer outputVector) { + public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, ByteBuffer outputVector) { try { - DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk, rowCount); - DuckDBWritableVector outputWriter = new DuckDBWritableVector(outputVector, rowCount); - function.apply(inputReader, rowCount, outputWriter); + DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk); + DuckDBWritableVector outputWriter = new DuckDBWritableVector(outputVector, inputReader.rowCount()); + function.apply(inputReader, outputWriter); } catch (Throwable throwable) { reportError(functionInfo, throwable); } diff --git a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java 
b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java index f73a979df..909abd6fd 100644 --- a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java +++ b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java @@ -78,7 +78,7 @@ static DuckDBVectorTypeInfo fromVector(ByteBuffer vectorRef) throws SQLException (int) internalType.widthBytes, decimalMeta); } default: - throw new SQLException("Unsupported vectorized scalar function type: " + capiType); + throw new SQLException("Unsupported scalar function vector type: " + capiType); } } finally { duckdb_destroy_logical_type(logicalType); diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index bd35d5447..54bd6255b 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -58,7 +58,8 @@ public static void test_register_scalar_function_typed_logical_type() throws Exc Statement stmt = conn.createStatement(); DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { conn.registerScalarFunction("java_add_int_typed", new DuckDBLogicalType[] {intType}, intType, - (input, rowCount, out) -> { + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -87,7 +88,8 @@ public static void test_register_scalar_function_parallel() throws Exception { DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { stmt.execute("PRAGMA threads=4"); conn.registerScalarFunction("java_add_one_bigint", new DuckDBLogicalType[] {bigintType}, bigintType, - (input, rowCount, out) -> { + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { out.setLong(i, in.getLong(i) + 1); @@ -108,7 +110,7 @@ public static void test_register_scalar_function_exception_propagation() throws Statement stmt = conn.createStatement(); 
DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { conn.registerScalarFunction("java_throws_exception", new DuckDBLogicalType[] {intType}, intType, - (input, rowCount, out) -> { throw new IllegalStateException("boom"); }); + (input, out) -> { throw new IllegalStateException("boom"); }); String message = assertThrows(() -> { stmt.executeQuery("SELECT java_throws_exception(1)"); }, SQLException.class); assertTrue(message.contains("Java scalar function threw exception")); @@ -119,8 +121,9 @@ public static void test_register_scalar_function_exception_propagation() throws public static void test_register_scalar_function_boolean() throws Exception { assertUnaryScalarFunction("java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -144,8 +147,9 @@ public static void test_register_scalar_function_boolean() throws Exception { public static void test_register_scalar_function_tinyint() throws Exception { assertUnaryScalarFunction("java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -170,8 +174,9 @@ public static void test_register_scalar_function_tinyint() throws Exception { public static void test_register_scalar_function_smallint() throws Exception { assertUnaryScalarFunction( "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -195,8 +200,9 @@ public static void test_register_scalar_function_smallint() throws Exception { public static void 
test_register_scalar_function_integer() throws Exception { assertUnaryScalarFunction("java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -220,8 +226,9 @@ public static void test_register_scalar_function_integer() throws Exception { public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -245,8 +252,9 @@ public static void test_register_scalar_function_integer_revalidates_after_null( public static void test_register_scalar_function_bigint() throws Exception { assertUnaryScalarFunction("java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -271,8 +279,9 @@ public static void test_register_scalar_function_bigint() throws Exception { public static void test_register_scalar_function_utinyint() throws Exception { assertUnaryScalarFunction( "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -297,8 +306,9 @@ public static void test_register_scalar_function_utinyint() throws Exception { public static void test_register_scalar_function_usmallint() throws Exception { assertUnaryScalarFunction( "java_add_usmallint", DuckDBColumnType.USMALLINT, 
DuckDBColumnType.USMALLINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -323,8 +333,9 @@ public static void test_register_scalar_function_usmallint() throws Exception { public static void test_register_scalar_function_uinteger() throws Exception { assertUnaryScalarFunction( "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -348,8 +359,9 @@ public static void test_register_scalar_function_uinteger() throws Exception { public static void test_register_scalar_function_ubigint() throws Exception { assertUnaryScalarFunction("java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); BigInteger increment = BigInteger.ONE; for (int i = 0; i < rowCount; i++) { @@ -376,8 +388,9 @@ public static void test_register_scalar_function_ubigint() throws Exception { public static void test_register_scalar_function_float() throws Exception { assertUnaryScalarFunction("java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -401,8 +414,9 @@ public static void test_register_scalar_function_float() throws Exception { public static void test_register_scalar_function_double() throws Exception { assertUnaryScalarFunction("java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < 
rowCount; i++) { if (in.isNull(i)) { @@ -427,8 +441,9 @@ public static void test_register_scalar_function_double() throws Exception { public static void test_register_scalar_function_decimal() throws Exception { try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); BigDecimal increment = new BigDecimal("0.0000000001"); for (int i = 0; i < rowCount; i++) { @@ -461,7 +476,8 @@ public static void test_register_scalar_function_decimal_precision_overflow() th Statement stmt = conn.createStatement(); DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { conn.registerScalarFunction("java_decimal_precision_overflow", new DuckDBLogicalType[] {decimalType}, - decimalType, (input, rowCount, out) -> { + decimalType, (input, out) -> { + int rowCount = input.rowCount(); for (int i = 0; i < rowCount; i++) { out.setBigDecimal(i, new BigDecimal("12345678901.23")); } @@ -479,7 +495,8 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws Statement stmt = conn.createStatement(); DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { conn.registerScalarFunction("java_decimal_scale_overflow", new DuckDBLogicalType[] {decimalType}, - decimalType, (input, rowCount, out) -> { + decimalType, (input, out) -> { + int rowCount = input.rowCount(); for (int i = 0; i < rowCount; i++) { out.setBigDecimal(i, new BigDecimal("1.234")); } @@ -495,8 +512,9 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws public static void test_register_scalar_function_date() throws Exception { assertUnaryScalarFunction( "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 
0; i < rowCount; i++) { if (in.isNull(i)) { @@ -521,8 +539,9 @@ public static void test_register_scalar_function_date() throws Exception { public static void test_register_scalar_function_date_from_java_util_date() throws Exception { assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -545,8 +564,9 @@ public static void test_register_scalar_function_date_from_java_util_date() thro public static void test_register_scalar_function_timestamp() throws Exception { assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -577,8 +597,9 @@ public static void test_register_scalar_function_timestamp() throws Exception { public static void test_register_scalar_function_timestamp_s() throws Exception { assertUnaryScalarFunction( "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -601,8 +622,9 @@ public static void test_register_scalar_function_timestamp_s() throws Exception public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_S, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -623,8 +645,9 @@ public static void 
test_register_scalar_function_timestamp_s_pre_epoch() throws public static void test_register_scalar_function_timestamp_ms() throws Exception { assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -649,8 +672,9 @@ public static void test_register_scalar_function_timestamp_ms() throws Exception public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_MS, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -672,8 +696,9 @@ public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws public static void test_register_scalar_function_timestamp_ns() throws Exception { assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -699,8 +724,9 @@ public static void test_register_scalar_function_timestamptz() throws Exception assertUnaryScalarFunction( "java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -726,8 +752,9 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr assertUnaryScalarFunction( "java_copy_timestamptz_with_timestamp", 
DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -750,8 +777,9 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_date", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); long oneSecondMillis = 1000L; for (int i = 0; i < rowCount; i++) { @@ -779,8 +807,9 @@ public static void test_register_scalar_function_timestamp_from_java_util_date() public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -803,8 +832,9 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_sql_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -826,8 +856,9 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws 
Exception { assertUnaryScalarFunction( "java_timestamp_from_util_sql_time", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); for (int i = 0; i < rowCount; i++) { java.util.Date value = Time.valueOf("12:34:56"); out.setTimestamp(i, value); @@ -844,8 +875,9 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -868,8 +900,9 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr public static void test_register_scalar_function_varchar() throws Exception { assertUnaryScalarFunction("java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -894,8 +927,9 @@ public static void test_register_scalar_function_varchar() throws Exception { public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { assertUnaryScalarFunction("java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { out.setString(i, in.getString(i)); @@ -916,8 +950,9 @@ public static void test_register_scalar_function_varchar_get_string_handles_null public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { 
assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, rowCount, out) + (input, out) -> { + int rowCount = input.rowCount(); DuckDBReadableVector in = input.vector(0); for (int i = 0; i < rowCount; i++) { if (in.isNull(i)) { @@ -940,7 +975,7 @@ public static void test_register_scalar_function_varchar_revalidates_after_null( } private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, - DuckDBColumnType returnType, DuckDBVectorizedScalarFunction function, + DuckDBColumnType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { @@ -949,7 +984,7 @@ private static void assertUnaryScalarFunction(String functionName, DuckDBColumnT } private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, - DuckDBLogicalType returnType, DuckDBVectorizedScalarFunction function, + DuckDBLogicalType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { From 211d7c19a2c0d279f66f32558227f897bd100bf1 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Mon, 6 Apr 2026 21:50:22 -0300 Subject: [PATCH 5/9] Implement and rework Java scalar UDF APIs around context callbacks Collapse the builder introduction and callback rework into one step, and keep JNI overload symbols exported consistently across toolchains. 
--- UDF.MD | 226 +- duckdb_java.exp | 6 +- duckdb_java.map | 6 +- src/jni/bindings_scalar_function.cpp | 36 +- src/jni/bindings_vector.cpp | 33 +- src/jni/duckdb_java.cpp | 2 - src/jni/refs.cpp | 12 - src/jni/refs.hpp | 6 - src/jni/scalar_functions.cpp | 48 +- src/jni/scalar_functions.hpp | 3 +- src/main/java/org/duckdb/DuckDBBindings.java | 14 +- .../java/org/duckdb/DuckDBConnection.java | 47 - .../org/duckdb/DuckDBDataChunkReader.java | 10 +- src/main/java/org/duckdb/DuckDBDriver.java | 30 + .../java/org/duckdb/DuckDBFunctionKind.java | 3 + src/main/java/org/duckdb/DuckDBFunctions.java | 12 + .../java/org/duckdb/DuckDBReadableVector.java | 253 +-- .../org/duckdb/DuckDBReadableVectorImpl.java | 286 +++ .../org/duckdb/DuckDBRegisteredFunction.java | 97 + .../java/org/duckdb/DuckDBScalarContext.java | 88 + .../java/org/duckdb/DuckDBScalarFunction.java | 8 +- .../duckdb/DuckDBScalarFunctionAdapter.java | 498 ++++ .../duckdb/DuckDBScalarFunctionBuilder.java | 467 ++++ .../duckdb/DuckDBScalarFunctionWrapper.java | 9 +- src/main/java/org/duckdb/DuckDBScalarRow.java | 360 +++ .../java/org/duckdb/DuckDBWritableVector.java | 489 +--- .../org/duckdb/DuckDBWritableVectorImpl.java | 707 ++++++ src/test/java/org/duckdb/TestBindings.java | 163 +- .../java/org/duckdb/TestScalarFunctions.java | 1998 +++++++++++++---- 29 files changed, 4626 insertions(+), 1291 deletions(-) create mode 100644 src/main/java/org/duckdb/DuckDBFunctionKind.java create mode 100644 src/main/java/org/duckdb/DuckDBFunctions.java create mode 100644 src/main/java/org/duckdb/DuckDBReadableVectorImpl.java create mode 100644 src/main/java/org/duckdb/DuckDBRegisteredFunction.java create mode 100644 src/main/java/org/duckdb/DuckDBScalarContext.java create mode 100644 src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java create mode 100644 src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java create mode 100644 src/main/java/org/duckdb/DuckDBScalarRow.java create mode 100644 
src/main/java/org/duckdb/DuckDBWritableVectorImpl.java diff --git a/UDF.MD b/UDF.MD index 1d641cc2c..0398165b7 100644 --- a/UDF.MD +++ b/UDF.MD @@ -1,71 +1,191 @@ # Java Scalar Functions (UDF) -Use `DuckDBConnection.registerScalarFunction` to register a scalar function in Java. +Use `DuckDBFunctions.scalarFunction()` to build and register scalar functions. +`register(java.sql.Connection)` returns a `DuckDBRegisteredFunction` with metadata about the registered scalar function. + +Registered functions are also tracked in a Java-side registry exposed by `DuckDBDriver.registeredFunctions()`. +This registry is bookkeeping for functions registered through the JDBC API, not an authoritative view of the DuckDB catalog. +`DuckDBDriver.clearFunctionsRegistry()` clears only the Java-side registry and does not de-register functions from DuckDB. + +## Recommended API (Functional Interfaces) + +Use these overloads for simple functions: + +- `withFunction(Supplier)` for zero arguments +- `withFunction(Function)` for one argument +- `withFunction(BiFunction)` for two arguments +- `withIntFunction(IntUnaryOperator | IntBinaryOperator)` for `INTEGER` unary/binary functions +- `withLongFunction(LongUnaryOperator | LongBinaryOperator)` for `BIGINT` unary/binary functions +- `withDoubleFunction(DoubleUnaryOperator | DoubleBinaryOperator)` for `DOUBLE` unary/binary functions + +### Simple example (`withIntFunction`) ```java -void registerScalarFunction( - String name, - DuckDBLogicalType[] parameterTypes, - DuckDBLogicalType returnType, - DuckDBScalarFunction function -) throws SQLException +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_one") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1) + .register(conn); +} ``` -Notes: -- The API uses typed logical types (`DuckDBLogicalType`) instead of SQL type strings. 
-- The callback processes all rows from the current input chunk (`input.rowCount()`) and writes one value per row into `out`. -- `DuckDBDataChunkReader` / `DuckDBReadableVector` / `DuckDBWritableVector` are valid only during the callback. +```sql +SELECT java_add_one(41); +``` -## Simple example +### Slightly more complex example (`withDoubleFunction`) ```java -try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { - try (DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - conn.registerScalarFunction("java_add_one", new DuckDBLogicalType[] {intType}, intType, (input, out) -> { - DuckDBReadableVector in = input.vector(0); - int rowCount = input.rowCount(); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setInt(i, in.getInt(i) + 1); - } - } - }); - } +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { + DuckDBFunctions.scalarFunction() + .withName("java_weighted_sum") + .withParameter(Double.class) + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction((x, w) -> x * w + 10.0) + .register(conn); } ``` -## More complete example +```sql +SELECT java_weighted_sum(2.5, 4.0); +``` + +Behavior: + +- `propagateNulls(true)` is the default. +- With `propagateNulls(true)`, NULL input propagates to NULL output for `Function`/`BiFunction`. +- With `propagateNulls(false)`, functional callbacks receive `null` arguments and can decide the output. +- `withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)` require `propagateNulls(true)`. +- For `Supplier`, returning `null` writes NULL output. +- `Function` and `BiFunction` are fixed arity only (no varargs). 
-Build a label from `TIMESTAMP` + `VARCHAR` + `DOUBLE`, preserving `NULL` behavior: +## Type declaration and mapping + +`withParameter(...)` and `withReturnType(...)` accept: + +- `Class` +- `DuckDBColumnType` +- `DuckDBLogicalType` + +Common class mappings include: + +- `Integer` -> `INTEGER` +- `Long` -> `BIGINT` +- `String` -> `VARCHAR` +- `BigDecimal` -> `DECIMAL` +- `LocalDate` and `java.sql.Date` -> `DATE` +- `LocalDateTime`, `java.sql.Timestamp`, and `java.util.Date` -> `TIMESTAMP` + +Use `DuckDBLogicalType.decimal(width, scale)` for explicit DECIMAL precision/scale. + +## Varargs + +Declare varargs type with `withVarArgs(DuckDBLogicalType)`. + +For functional varargs, use `withVarArgsFunction(Function)`: ```java -try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class)) { - try (DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); - DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); - DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { - conn.registerScalarFunction( - "java_event_label", - new DuckDBLogicalType[] {tsType, strType, dblType}, - strType, - (input, out) -> { - DuckDBReadableVector ts = input.vector(0); - DuckDBReadableVector tag = input.vector(1); - DuckDBReadableVector score = input.vector(2); - int rowCount = input.rowCount(); - - for (int i = 0; i < rowCount; i++) { - if (ts.isNull(i) || tag.isNull(i) || score.isNull(i)) { - out.setNull(i); - continue; - } - String value = - ts.getLocalDateTime(i) + " | " + tag.getString(i).trim().toUpperCase() + " | " + score.getDouble(i); - out.setString(i, value); - } +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs") + .withParameter(Integer.class) // fixed argument(s) + .withVarArgs(intType) // variadic argument type + 
.withReturnType(Integer.class) + .withVarArgsFunction(args -> { + int sum = 0; + for (Object arg : args) { + sum += (Integer) arg; } - ); - } + return sum; + }) + .register(conn); +} +``` + +```sql +SELECT java_sum_varargs(1, 2, 3, 4); +``` + +Notes: + +- `withFunction(Function)` and `withFunction(BiFunction)` reject varargs. +- `withVarArgsFunction(...)` requires `withVarArgs(...)` first. + +## Builder methods + +- `withName(String)` +- `withParameter(Class | DuckDBColumnType | DuckDBLogicalType)` +- `withReturnType(Class | DuckDBColumnType | DuckDBLogicalType)` +- `withFunction(Supplier | Function | BiFunction)` +- `withIntFunction(IntUnaryOperator | IntBinaryOperator)` +- `withLongFunction(LongUnaryOperator | LongBinaryOperator)` +- `withDoubleFunction(DoubleUnaryOperator | DoubleBinaryOperator)` +- `withVarArgs(DuckDBLogicalType)` +- `withVarArgsFunction(Function)` +- `withVectorizedFunction(DuckDBScalarFunction)` +- `propagateNulls(boolean)` +- `withVolatile()` +- `withSpecialHandling()` +- `register(java.sql.Connection)` + +## Registered Function Metadata And Registry + +`DuckDBRegisteredFunction` exposes immutable metadata about the successful registration result: + +- `name()` +- `functionKind()` +- `isScalar()` +- parameter and return type metadata +- callback and flags used at registration time + +To inspect Java-side registrations: + +```java +List functions = DuckDBDriver.registeredFunctions(); +``` + +The returned list is read-only. Duplicate function names may appear in the registry. + +## Advanced API (`DuckDBScalarFunction`) + +Use `withVectorizedFunction(...)` for full context control through `DuckDBScalarContext`. 
+ +Example with multiple input types (`TIMESTAMP`, `VARCHAR`, `DOUBLE`) and `VARCHAR` output: + +```java +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); + DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); + DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); + DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { + DuckDBFunctions.scalarFunction() + .withName("java_event_label") + .withParameter(tsType) + .withParameter(strType) + .withParameter(dblType) + .withReturnType(strType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { + ctx.stream().forEachOrdered(row -> { + String value = row.getLocalDateTime(0) + " | " + + row.getString(1).trim().toUpperCase() + + " | " + row.getDouble(2); + row.setString(value); + }); + }) + .register(conn); } ``` + +```sql +SELECT java_event_label(TIMESTAMP '2026-04-04 12:00:00', 'launch', 4.5); +``` + +Lifecycle rules: + +- `DuckDBScalarContext`, `DuckDBScalarRow`, `DuckDBReadableVector`, and `DuckDBWritableVector` are valid only during callback execution. +- Write exactly one output value per input row for each callback invocation. +- With `propagateNulls(true)`, `DuckDBScalarContext.stream()` skips rows that contain NULL in any input column and writes NULL to the output for those rows. 
diff --git a/duckdb_java.exp b/duckdb_java.exp index cad64ce4e..b0611b6f6 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -54,9 +54,13 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling _Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error -_Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes +_Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string +_Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type _Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function diff --git a/duckdb_java.map b/duckdb_java.map index e8179a6b3..0fc9c6fc7 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -56,9 +56,13 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling; Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error; - 
Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes; + Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string; + Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp index f84a8f6f0..54e6557a9 100644 --- a/src/jni/bindings_scalar_function.cpp +++ b/src/jni/bindings_scalar_function.cpp @@ -91,6 +91,38 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_return_type(function, type); } +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs(JNIEnv *env, jclass, + jobject scalar_function, + jobject logical_type) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto type = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_varargs(function, type); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile(JNIEnv *env, jclass, + jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_volatile(function); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling( + JNIEnv *env, jclass, jobject scalar_function) { + auto function = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_scalar_function_set_special_handling(function); +} + JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function(JNIEnv *env, 
jclass, jobject connection, jobject scalar_function) { @@ -106,9 +138,9 @@ JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1 } JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function( - JNIEnv *env, jclass, jobject conn_ref_buf, jobject scalar_function_buf, jobject function_j) { + JNIEnv *env, jclass, jobject scalar_function_buf, jobject function_j) { try { - duckdb_jdbc_scalar_function_set_function(env, conn_ref_buf, scalar_function_buf, function_j); + scalar_function_set_function(env, scalar_function_buf, function_j); } catch (const std::exception &e) { duckdb::ErrorData error(e); ThrowJNI(env, error.Message().c_str()); diff --git a/src/jni/bindings_vector.cpp b/src/jni/bindings_vector.cpp index 9bda6e778..f8f37bf7c 100644 --- a/src/jni/bindings_vector.cpp +++ b/src/jni/bindings_vector.cpp @@ -23,11 +23,12 @@ static duckdb_vector vector_buf_to_vector(JNIEnv *env, jobject vector_buf) { /* * Class: org_duckdb_DuckDBBindings - * Method: duckdb_jdbc_varchar_string_bytes - * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)[B + * Method: duckdb_vector_get_string + * Signature: (Ljava/nio/ByteBuffer;J)[B */ -JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes( - JNIEnv *env, jclass, jobject vector_data, jobject validity, jlong row_count, jlong row) { +JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string(JNIEnv *env, jclass, + jobject vector_data, + jlong row) { if (vector_data == nullptr) { env->ThrowNew(J_SQLException, "Invalid vector data buffer"); @@ -38,34 +39,22 @@ JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varcha env->ThrowNew(J_SQLException, "Invalid vector data"); return nullptr; } - idx_t row_count_idx = jlong_to_idx(env, row_count); - if (env->ExceptionCheck()) { - return nullptr; - } idx_t row_idx = jlong_to_idx(env, row); if (env->ExceptionCheck()) { return nullptr; } - if 
(row_idx >= row_count_idx) { - env->ThrowNew(J_SQLException, "Row index out of bounds"); - return nullptr; - } - if (validity != nullptr) { - auto mask = reinterpret_cast(env->GetDirectBufferAddress(validity)); - if (mask == nullptr) { - env->ThrowNew(J_SQLException, "Invalid validity buffer"); - return nullptr; - } - if ((mask[row_idx / 64] & (1ULL << (row_idx % 64))) == 0) { - return nullptr; - } - } auto &string_value = data[row_idx]; auto string_len = duckdb_string_t_length(string_value); auto string_ptr = duckdb_string_t_data(&string_value); return make_jbyteArray(env, string_ptr, string_len); } +extern "C" JNIEXPORT jbyteArray JNICALL +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J(JNIEnv *env, jclass clazz, + jobject vector_data, jlong row) { + return Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string(env, clazz, vector_data, row); +} + /* * Class: org_duckdb_DuckDBBindings * Method: duckdb_create_vector diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 22c51f16c..f8a5b1964 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -65,8 +65,6 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) { delete_global_refs(env); } -jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_count); - //! 
The database instance cache, used so that multiple connections to the same file point to the same database object duckdb::DBInstanceCache instance_cache; diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index 8e7e7799e..1dc6cbafe 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -61,12 +61,6 @@ jmethodID J_DuckVector_retainConstlenData; jfieldID J_DuckVector_constlen; jfieldID J_DuckVector_varlen; -jclass J_DuckDataChunkReader; -jmethodID J_DuckDataChunkReader_init; - -jclass J_DuckWritableVector; -jmethodID J_DuckWritableVector_init; - jclass J_DuckArray; jmethodID J_DuckArray_init; @@ -293,12 +287,6 @@ void create_refs(JNIEnv *env) { J_DuckVector_constlen = get_field_id(env, J_DuckVector, "constlen_data", "Ljava/nio/ByteBuffer;"); J_DuckVector_varlen = get_field_id(env, J_DuckVector, "varlen_data", "[Ljava/lang/Object;"); - J_DuckDataChunkReader = make_class_ref(env, "org/duckdb/DuckDBDataChunkReader"); - J_DuckDataChunkReader_init = get_method_id(env, J_DuckDataChunkReader, "", "(Ljava/nio/ByteBuffer;)V"); - - J_DuckWritableVector = make_class_ref(env, "org/duckdb/DuckDBWritableVector"); - J_DuckWritableVector_init = get_method_id(env, J_DuckWritableVector, "", "(Ljava/nio/ByteBuffer;I)V"); - J_ByteBuffer = make_class_ref(env, "java/nio/ByteBuffer"); J_ByteBuffer_order = get_method_id(env, J_ByteBuffer, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;"); J_ByteOrder = make_class_ref(env, "java/nio/ByteOrder"); diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index 94236dcf8..cda859d33 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -58,12 +58,6 @@ extern jmethodID J_DuckVector_retainConstlenData; extern jfieldID J_DuckVector_constlen; extern jfieldID J_DuckVector_varlen; -extern jclass J_DuckDataChunkReader; -extern jmethodID J_DuckDataChunkReader_init; - -extern jclass J_DuckWritableVector; -extern jmethodID J_DuckWritableVector_init; - extern jclass J_DuckArray; extern jmethodID J_DuckArray_init; diff --git a/src/jni/scalar_functions.cpp 
b/src/jni/scalar_functions.cpp index 563ecd5c0..054561ac1 100644 --- a/src/jni/scalar_functions.cpp +++ b/src/jni/scalar_functions.cpp @@ -2,12 +2,27 @@ extern "C" { #include "duckdb.h" } -#include "duckdb/common/exception.hpp" #include "holders.hpp" #include "refs.hpp" #include "scalar_functions.hpp" #include "util.hpp" +#include +#include + +class ScalarFunctionException : public std::exception { +public: + explicit ScalarFunctionException(std::string message_p) : message(std::move(message_p)) { + } + + const char *what() const noexcept override { + return message.c_str(); + } + +private: + std::string message; +}; + struct JNIEnvGuard { JavaVM *vm; JNIEnv *env; @@ -15,18 +30,18 @@ struct JNIEnvGuard { explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { if (!vm) { - throw duckdb::InvalidInputException("JVM is not available"); + throw ScalarFunctionException("JVM is not available"); } auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (get_env_status == JNI_OK) { return; } if (get_env_status != JNI_EDETACHED) { - throw duckdb::InvalidInputException("Failed to get JNI environment"); + throw ScalarFunctionException("Failed to get JNI environment"); } auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); if (attach_status != JNI_OK || !env) { - throw duckdb::InvalidInputException("Failed to attach current thread to JVM"); + throw ScalarFunctionException("Failed to attach current thread to JVM"); } detach_when_done = true; } @@ -106,7 +121,7 @@ static std::string consume_java_exception_message(JNIEnv *env) { static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { if (!vm) { - throw duckdb::InvalidInputException("JVM is not available"); + throw ScalarFunctionException("JVM is not available"); } detach_when_done = false; @@ -115,12 +130,12 @@ static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_do return; } if (get_env_status != 
JNI_EDETACHED) { - throw duckdb::InvalidInputException("Failed to get JNI environment"); + throw ScalarFunctionException("Failed to get JNI environment"); } auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); if (attach_status != JNI_OK || !env) { - throw duckdb::InvalidInputException("Failed to attach current thread to JVM"); + throw ScalarFunctionException("Failed to attach current thread to JVM"); } detach_when_done = true; } @@ -142,8 +157,8 @@ static void execute_java_scalar_function(JNIEnv *env, JavaScalarFunctionState &s } if (env->ExceptionCheck()) { - throw duckdb::InvalidInputException("Java scalar function wrapper threw exception: %s", - consume_java_exception_message(env)); + throw ScalarFunctionException("Java scalar function wrapper threw exception: " + + consume_java_exception_message(env)); } } @@ -158,33 +173,28 @@ static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, con env->DeleteLocalRef(callback_class); if (!apply_method || env->ExceptionCheck()) { consume_java_exception_message(env); - throw duckdb::InvalidInputException("%s", error_message); + throw ScalarFunctionException(error_message); } return apply_method; } -void duckdb_jdbc_scalar_function_set_function(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, - jobject function_j) { - auto connection = get_connection(env, conn_ref_buf); - if (!connection) { - throw duckdb::InvalidInputException("Invalid connection"); - } +void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j) { auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); if (env->ExceptionCheck()) { return; } if (!function_j) { - throw duckdb::InvalidInputException("Invalid scalar function callback"); + throw ScalarFunctionException("Invalid scalar function callback"); } JavaVM *vm = nullptr; if (env->GetJavaVM(&vm) != JNI_OK || !vm) { - throw duckdb::InvalidInputException("Failed to get JVM 
reference"); + throw ScalarFunctionException("Failed to get JVM reference"); } auto callback_ref = env->NewGlobalRef(function_j); if (!callback_ref) { - throw duckdb::InvalidInputException("Could not create global reference for scalar function callback"); + throw ScalarFunctionException("Could not create global reference for scalar function callback"); } try { diff --git a/src/jni/scalar_functions.hpp b/src/jni/scalar_functions.hpp index d5bea4265..966213bb0 100644 --- a/src/jni/scalar_functions.hpp +++ b/src/jni/scalar_functions.hpp @@ -2,5 +2,4 @@ #include "bindings.hpp" -void duckdb_jdbc_scalar_function_set_function(JNIEnv *env, jobject conn_ref_buf, jobject scalar_function_buf, - jobject function_j); +void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j); diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 56a4c5689..eb535881b 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -29,16 +29,18 @@ public class DuckDBBindings { static native void duckdb_scalar_function_set_return_type(ByteBuffer scalarFunction, ByteBuffer logicalType); + static native void duckdb_scalar_function_set_varargs(ByteBuffer scalarFunction, ByteBuffer logicalType); + + static native void duckdb_scalar_function_set_volatile(ByteBuffer scalarFunction); + + static native void duckdb_scalar_function_set_special_handling(ByteBuffer scalarFunction); + static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalarFunction); - static native void duckdb_scalar_function_set_function(ByteBuffer connection, ByteBuffer scalarFunction, - Object function); + static native void duckdb_scalar_function_set_function(ByteBuffer scalarFunction, Object function); static native void duckdb_scalar_function_set_error(ByteBuffer functionInfo, byte[] error); - static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer 
vectorData, ByteBuffer validity, long rowCount, - long row); - // logical type static native ByteBuffer duckdb_create_logical_type(int duckdb_type); @@ -89,6 +91,8 @@ static native byte[] duckdb_jdbc_varchar_string_bytes(ByteBuffer vectorData, Byt static native void duckdb_vector_assign_string_element_len(ByteBuffer vector, long index, byte[] str); + static native byte[] duckdb_vector_get_string(ByteBuffer vectorData, long row); + static native ByteBuffer duckdb_list_vector_get_child(ByteBuffer vector); static native long duckdb_list_vector_get_size(ByteBuffer vector); diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index b27a39597..d51c0c00e 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -500,53 +500,6 @@ public void registerArrowStream(String name, Object arrow_array_stream) { } } - public void registerScalarFunction(String name, DuckDBLogicalType[] parameterTypes, DuckDBLogicalType returnType, - DuckDBScalarFunction function) throws SQLException { - checkOpen(); - connRefLock.lock(); - ByteBuffer scalarFunction = null; - try { - checkOpen(); - if (name == null || name.trim().isEmpty()) { - throw new SQLException("Function name cannot be null or empty"); - } - if (parameterTypes == null) { - throw new SQLException("Parameter types cannot be null"); - } - for (int i = 0; i < parameterTypes.length; i++) { - if (parameterTypes[i] == null) { - throw new SQLException("Parameter type at index " + i + " cannot be null"); - } - } - if (returnType == null) { - throw new SQLException("Return type cannot be null"); - } - if (function == null) { - throw new SQLException("Scalar function callback cannot be null"); - } - - scalarFunction = DuckDBBindings.duckdb_create_scalar_function(); - DuckDBBindings.duckdb_scalar_function_set_name(scalarFunction, name.getBytes(UTF_8)); - - for (int i = 0; i < parameterTypes.length; i++) { - 
DuckDBBindings.duckdb_scalar_function_add_parameter(scalarFunction, parameterTypes[i].logicalTypeRef()); - } - - DuckDBBindings.duckdb_scalar_function_set_return_type(scalarFunction, returnType.logicalTypeRef()); - DuckDBBindings.duckdb_scalar_function_set_function(connRef, scalarFunction, - new DuckDBScalarFunctionWrapper(function)); - - if (DuckDBBindings.duckdb_register_scalar_function(connRef, scalarFunction) != 0) { - throw new SQLException("Failed to register scalar function '" + name + "'"); - } - } finally { - if (scalarFunction != null) { - DuckDBBindings.duckdb_destroy_scalar_function(scalarFunction); - } - connRefLock.unlock(); - } - } - public String getProfilingInformation(ProfilerPrintFormat format) throws SQLException { checkOpen(); connRefLock.lock(); diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java index af628f5a3..a628ce1ff 100644 --- a/src/main/java/org/duckdb/DuckDBDataChunkReader.java +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -7,7 +7,7 @@ public final class DuckDBDataChunkReader { private final ByteBuffer chunkRef; - private final int rowCount; + private final long rowCount; private final int columnCount; private final DuckDBReadableVector[] vectors; @@ -16,12 +16,12 @@ public final class DuckDBDataChunkReader { throw new SQLException("Invalid data chunk reference"); } this.chunkRef = chunkRef; - this.rowCount = (int) duckdb_data_chunk_get_size(chunkRef); - this.columnCount = (int) duckdb_data_chunk_get_column_count(chunkRef); + this.rowCount = duckdb_data_chunk_get_size(chunkRef); + this.columnCount = Math.toIntExact(duckdb_data_chunk_get_column_count(chunkRef)); this.vectors = new DuckDBReadableVector[columnCount]; } - public int rowCount() { + public long rowCount() { return rowCount; } @@ -36,7 +36,7 @@ public DuckDBReadableVector vector(int columnIndex) throws SQLException { DuckDBReadableVector vector = vectors[columnIndex]; if (vector == null) { 
ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); - vector = new DuckDBReadableVector(vectorRef, rowCount); + vector = new DuckDBReadableVectorImpl(vectorRef, rowCount); vectors[columnIndex] = vector; } return vector; diff --git a/src/main/java/org/duckdb/DuckDBDriver.java b/src/main/java/org/duckdb/DuckDBDriver.java index e045d14f7..9e9c4cffa 100644 --- a/src/main/java/org/duckdb/DuckDBDriver.java +++ b/src/main/java/org/duckdb/DuckDBDriver.java @@ -42,6 +42,9 @@ public class DuckDBDriver implements java.sql.Driver { private static boolean pinnedDbRefsShutdownHookRegistered = false; private static boolean pinnedDbRefsShutdownHookRun = false; + private static final ArrayList functionsRegistry = new ArrayList<>(); + private static final ReentrantLock functionsRegistryLock = new ReentrantLock(); + private static final Set supportedOptions = new LinkedHashSet<>(); private static final ReentrantLock supportedOptionsLock = new ReentrantLock(); @@ -263,6 +266,33 @@ public static boolean shutdownQueryCancelScheduler() { return true; } + public static List registeredFunctions() { + functionsRegistryLock.lock(); + try { + return Collections.unmodifiableList(new ArrayList<>(functionsRegistry)); + } finally { + functionsRegistryLock.unlock(); + } + } + + public static void clearFunctionsRegistry() { + functionsRegistryLock.lock(); + try { + functionsRegistry.clear(); + } finally { + functionsRegistryLock.unlock(); + } + } + + static void registerFunction(DuckDBRegisteredFunction function) { + functionsRegistryLock.lock(); + try { + functionsRegistry.add(function); + } finally { + functionsRegistryLock.unlock(); + } + } + private static DriverPropertyInfo createDriverPropInfo(String name, String value, String description) { DriverPropertyInfo dpi = new DriverPropertyInfo(name, value); dpi.description = description; diff --git a/src/main/java/org/duckdb/DuckDBFunctionKind.java b/src/main/java/org/duckdb/DuckDBFunctionKind.java new file mode 100644 
index 000000000..1ec4c8f46 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBFunctionKind.java @@ -0,0 +1,3 @@ +package org.duckdb; + +public enum DuckDBFunctionKind { SCALAR } diff --git a/src/main/java/org/duckdb/DuckDBFunctions.java b/src/main/java/org/duckdb/DuckDBFunctions.java new file mode 100644 index 000000000..a8e1dad3e --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBFunctions.java @@ -0,0 +1,12 @@ +package org.duckdb; + +import java.sql.SQLException; + +public final class DuckDBFunctions { + private DuckDBFunctions() { + } + + public static DuckDBScalarFunctionBuilder scalarFunction() throws SQLException { + return new DuckDBScalarFunctionBuilder(); + } +} diff --git a/src/main/java/org/duckdb/DuckDBReadableVector.java b/src/main/java/org/duckdb/DuckDBReadableVector.java index c2db43a7b..b997f53a1 100644 --- a/src/main/java/org/duckdb/DuckDBReadableVector.java +++ b/src/main/java/org/duckdb/DuckDBReadableVector.java @@ -1,262 +1,57 @@ package org.duckdb; -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.duckdb.DuckDBBindings.*; - import java.math.BigDecimal; import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.sql.Date; import java.sql.SQLException; import java.sql.Timestamp; -import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; -import java.time.ZoneId; -import java.time.temporal.ChronoUnit; - -public final class DuckDBReadableVector { - private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); - private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); - - private final ByteBuffer vectorRef; - private final int rowCount; - private final DuckDBVectorTypeInfo typeInfo; - private final ByteBuffer data; - private final ByteBuffer validity; - - DuckDBReadableVector(ByteBuffer vectorRef, int rowCount) throws SQLException { - if (vectorRef == null) { - throw new 
SQLException("Invalid vector reference"); - } - this.vectorRef = vectorRef; - this.rowCount = rowCount; - this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); - this.data = duckdb_vector_get_data(vectorRef, (long) rowCount * typeInfo.widthBytes); - this.validity = duckdb_vector_get_validity(vectorRef, rowCount); - } - - public DuckDBColumnType getType() { - return typeInfo.columnType; - } - - public int rowCount() { - return rowCount; - } - - public boolean isNull(int row) { - checkRowIndex(row); - if (validity == null) { - return false; - } - int entryPos = (row / 64) * Long.BYTES; - long mask = validity.order(NATIVE_ORDER).getLong(entryPos); - return (mask & (1L << (row % 64))) == 0; - } - - public boolean getBoolean(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.BOOLEAN); - return data.get(row) != 0; - } - - public byte getByte(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.TINYINT); - return data.get(row); - } +import java.util.stream.LongStream; - public short getShort(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.SMALLINT); - return data.order(NATIVE_ORDER).getShort(row * Short.BYTES); - } +public interface DuckDBReadableVector { + DuckDBColumnType getType(); - public short getUint8(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.UTINYINT); - return (short) Byte.toUnsignedInt(data.get(row)); - } + long rowCount(); - public int getUint16(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.USMALLINT); - return Short.toUnsignedInt(data.order(NATIVE_ORDER).getShort(row * Short.BYTES)); - } + LongStream rowIndexStream(); - public int getInt(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.INTEGER); - return data.order(NATIVE_ORDER).getInt(row * Integer.BYTES); - } + boolean isNull(long row); - public long getUint32(int row) throws SQLException { 
- checkRowIndex(row); - requireType(DuckDBColumnType.UINTEGER); - return Integer.toUnsignedLong(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES)); - } + boolean getBoolean(long row) throws SQLException; - public long getLong(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.BIGINT); - return data.order(NATIVE_ORDER).getLong(row * Long.BYTES); - } + byte getByte(long row) throws SQLException; - public BigInteger getUint64(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.UBIGINT); - long value = data.order(NATIVE_ORDER).getLong(row * Long.BYTES); - return unsignedLongToBigInteger(value); - } + short getShort(long row) throws SQLException; - public float getFloat(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.FLOAT); - return data.order(NATIVE_ORDER).getFloat(row * Float.BYTES); - } + short getUint8(long row) throws SQLException; - public double getDouble(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DOUBLE); - return data.order(NATIVE_ORDER).getDouble(row * Double.BYTES); - } + int getUint16(long row) throws SQLException; - public LocalDate getLocalDate(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DATE); - return LocalDate.ofEpochDay(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES)); - } + int getInt(long row) throws SQLException; - public Date getDate(int row) throws SQLException { - return Date.valueOf(getLocalDate(row)); - } + long getUint32(long row) throws SQLException; - public LocalDateTime getLocalDateTime(int row) throws SQLException { - checkRowIndex(row); - requireTimestampType(); - long epochValue = data.order(NATIVE_ORDER).getLong(row * Long.BYTES); - switch (typeInfo.capiType) { - case DUCKDB_TYPE_TIMESTAMP_S: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); - case DUCKDB_TYPE_TIMESTAMP_MS: - return 
DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); - case DUCKDB_TYPE_TIMESTAMP: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); - case DUCKDB_TYPE_TIMESTAMP_NS: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); - case DUCKDB_TYPE_TIMESTAMP_TZ: - return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); - default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } + long getLong(long row) throws SQLException; - public Timestamp getTimestamp(int row) throws SQLException { - return Timestamp.valueOf(getLocalDateTime(row)); - } + BigInteger getUint64(long row) throws SQLException; - public OffsetDateTime getOffsetDateTime(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); - long micros = data.order(NATIVE_ORDER).getLong(row * Long.BYTES); - Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); - return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); - } + float getFloat(long row) throws SQLException; - public BigDecimal getBigDecimal(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DECIMAL); - switch (typeInfo.storageType) { - case DUCKDB_TYPE_SMALLINT: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getShort(row * Short.BYTES), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_INTEGER: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getInt(row * Integer.BYTES), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_BIGINT: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getLong(row * Long.BYTES), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_HUGEINT: { - ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); - slice.position(row * typeInfo.widthBytes); - long lower = slice.getLong(); - long upper = slice.getLong(); - return new BigDecimal(upper) - 
.multiply(ULONG_MULTIPLIER) - .add(new BigDecimal(Long.toUnsignedString(lower))) - .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); - } - default: - throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); - } - } + double getDouble(long row) throws SQLException; - public String getString(int row) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.VARCHAR); - if (isNull(row)) { - return null; - } - byte[] bytes = duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, row); - if (bytes == null) { - return null; - } - return new String(bytes, UTF_8); - } + LocalDate getLocalDate(long row) throws SQLException; - ByteBuffer vectorRef() { - return vectorRef; - } + Date getDate(long row) throws SQLException; - private void requireType(DuckDBColumnType expected) throws SQLException { - if (typeInfo.columnType != expected) { - throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType); - } - } + LocalDateTime getLocalDateTime(long row) throws SQLException; - private void requireTimestampType() throws SQLException { - switch (typeInfo.columnType) { - case TIMESTAMP: - case TIMESTAMP_S: - case TIMESTAMP_MS: - case TIMESTAMP_NS: - case TIMESTAMP_WITH_TIME_ZONE: - return; - default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } + Timestamp getTimestamp(long row) throws SQLException; - private void checkRowIndex(int row) { - if (row < 0 || row >= rowCount) { - throw new IndexOutOfBoundsException("Row index out of bounds: " + row); - } - } + OffsetDateTime getOffsetDateTime(long row) throws SQLException; - private static Instant instantFromEpoch(long value, ChronoUnit unit) throws SQLException { - switch (unit) { - case SECONDS: - return Instant.ofEpochSecond(value); - case MILLIS: - return Instant.ofEpochMilli(value); - case MICROS: { - long epochSecond = Math.floorDiv(value, 1_000_000L); - long nanoAdjustment = Math.floorMod(value, 
1_000_000L) * 1000L; - return Instant.ofEpochSecond(epochSecond, nanoAdjustment); - } - case NANOS: { - long epochSecond = Math.floorDiv(value, 1_000_000_000L); - long nanoAdjustment = Math.floorMod(value, 1_000_000_000L); - return Instant.ofEpochSecond(epochSecond, nanoAdjustment); - } - default: - throw new SQLException("Unsupported unit type: " + unit); - } - } + BigDecimal getBigDecimal(long row) throws SQLException; - private static BigInteger unsignedLongToBigInteger(long value) { - if (value >= 0) { - return BigInteger.valueOf(value); - } - return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); - } + String getString(long row) throws SQLException; } diff --git a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java new file mode 100644 index 000000000..1f3d27184 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java @@ -0,0 +1,286 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.stream.LongStream; + +final class DuckDBReadableVectorImpl implements DuckDBReadableVector { + private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private final ByteBuffer validity; + + DuckDBReadableVectorImpl(ByteBuffer vectorRef, long rowCount) throws 
SQLException { + if (vectorRef == null) { + throw new SQLException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)); + this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + } + + @Override + public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + @Override + public long rowCount() { + return rowCount; + } + + @Override + public LongStream rowIndexStream() { + return LongStream.range(0, rowCount); + } + + @Override + public boolean isNull(long row) { + checkRowIndex(row); + if (validity == null) { + return false; + } + int entryPos = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); + long mask = validity.order(NATIVE_ORDER).getLong(entryPos); + return (mask & (1L << (row % Long.SIZE))) == 0; + } + + @Override + public boolean getBoolean(long row) throws SQLException { + requireType(DuckDBColumnType.BOOLEAN); + return data.get(checkedRowIndex(row)) != 0; + } + + @Override + public byte getByte(long row) throws SQLException { + requireType(DuckDBColumnType.TINYINT); + return data.get(checkedRowIndex(row)); + } + + @Override + public short getShort(long row) throws SQLException { + requireType(DuckDBColumnType.SMALLINT); + return data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES)); + } + + @Override + public short getUint8(long row) throws SQLException { + requireType(DuckDBColumnType.UTINYINT); + return (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + @Override + public int getUint16(long row) throws SQLException { + requireType(DuckDBColumnType.USMALLINT); + return Short.toUnsignedInt(data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES))); + } + + @Override + public int getInt(long row) throws SQLException { + requireType(DuckDBColumnType.INTEGER); + return 
data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES)); + } + + @Override + public long getUint32(long row) throws SQLException { + requireType(DuckDBColumnType.UINTEGER); + return Integer.toUnsignedLong(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public long getLong(long row) throws SQLException { + requireType(DuckDBColumnType.BIGINT); + return data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + } + + @Override + public BigInteger getUint64(long row) throws SQLException { + requireType(DuckDBColumnType.UBIGINT); + long value = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + return unsignedLongToBigInteger(value); + } + + @Override + public float getFloat(long row) throws SQLException { + requireType(DuckDBColumnType.FLOAT); + return data.order(NATIVE_ORDER).getFloat(checkedByteOffset(row, Float.BYTES)); + } + + @Override + public double getDouble(long row) throws SQLException { + requireType(DuckDBColumnType.DOUBLE); + return data.order(NATIVE_ORDER).getDouble(checkedByteOffset(row, Double.BYTES)); + } + + @Override + public LocalDate getLocalDate(long row) throws SQLException { + requireType(DuckDBColumnType.DATE); + return LocalDate.ofEpochDay(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public Date getDate(long row) throws SQLException { + return Date.valueOf(getLocalDate(row)); + } + + @Override + public LocalDateTime getLocalDateTime(long row) throws SQLException { + requireTimestampType(); + long epochValue = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); + case DUCKDB_TYPE_TIMESTAMP: + return 
DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); + default: + throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + @Override + public Timestamp getTimestamp(long row) throws SQLException { + return Timestamp.valueOf(getLocalDateTime(row)); + } + + @Override + public OffsetDateTime getOffsetDateTime(long row) throws SQLException { + requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + long micros = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); + return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); + } + + @Override + public BigDecimal getBigDecimal(long row) throws SQLException { + requireType(DuckDBColumnType.DECIMAL); + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + return BigDecimal.valueOf(data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES)), + typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_INTEGER: + return BigDecimal.valueOf(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES)), + typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_BIGINT: + return BigDecimal.valueOf(data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)), + typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_HUGEINT: { + ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); + slice.position(checkedByteOffset(row, typeInfo.widthBytes)); + long lower = slice.getLong(); + long upper = slice.getLong(); + return new BigDecimal(upper) + .multiply(ULONG_MULTIPLIER) + .add(new BigDecimal(Long.toUnsignedString(lower))) + .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); + } + default: + throw new 
SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + } + + @Override + public String getString(long row) throws SQLException { + requireType(DuckDBColumnType.VARCHAR); + if (isNull(row)) { + return null; + } + byte[] bytes = duckdb_vector_get_string(data, row); + if (bytes == null) { + return null; + } + return new String(bytes, UTF_8); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void requireType(DuckDBColumnType expected) throws SQLException { + if (typeInfo.columnType != expected) { + throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType); + } + } + + private void requireTimestampType() throws SQLException { + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return; + default: + throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private void checkRowIndex(long row) { + if (row < 0 || row >= rowCount) { + throw new IndexOutOfBoundsException("Row index out of bounds: " + row); + } + } + + private int checkedRowIndex(long row) { + checkRowIndex(row); + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + checkRowIndex(row); + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } + + private static Instant instantFromEpoch(long value, ChronoUnit unit) throws SQLException { + switch (unit) { + case SECONDS: + return Instant.ofEpochSecond(value); + case MILLIS: + return Instant.ofEpochMilli(value); + case MICROS: { + long epochSecond = Math.floorDiv(value, 1_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000L) * 1000L; + return Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + case NANOS: { + long epochSecond = Math.floorDiv(value, 1_000_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000_000L); + return 
Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + default: + throw new SQLException("Unsupported unit type: " + unit); + } + } + + private static BigInteger unsignedLongToBigInteger(long value) { + if (value >= 0) { + return BigInteger.valueOf(value); + } + return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); + } +} diff --git a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java new file mode 100644 index 000000000..f824a00b9 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java @@ -0,0 +1,97 @@ +package org.duckdb; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public final class DuckDBRegisteredFunction { + private final String name; + private final DuckDBFunctionKind functionKind; + private final List parameterTypes; + private final List parameterColumnTypes; + private final DuckDBLogicalType returnType; + private final DuckDBColumnType returnColumnType; + private final DuckDBScalarFunction function; + private final DuckDBLogicalType varArgType; + private final boolean volatileFlag; + private final boolean specialHandlingFlag; + private final boolean propagateNullsFlag; + + private DuckDBRegisteredFunction(String name, DuckDBFunctionKind functionKind, + List parameterTypes, + List parameterColumnTypes, DuckDBLogicalType returnType, + DuckDBColumnType returnColumnType, DuckDBScalarFunction function, + DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, + boolean propagateNullsFlag) { + this.name = name; + this.functionKind = functionKind; + this.parameterTypes = parameterTypes; + this.parameterColumnTypes = parameterColumnTypes; + this.returnType = returnType; + this.returnColumnType = returnColumnType; + this.function = function; + this.varArgType = varArgType; + this.volatileFlag = volatileFlag; + this.specialHandlingFlag = specialHandlingFlag; + this.propagateNullsFlag = 
propagateNullsFlag; + } + + public String name() { + return name; + } + + public DuckDBFunctionKind functionKind() { + return functionKind; + } + + public List parameterTypes() { + return parameterTypes; + } + + public List parameterColumnTypes() { + return parameterColumnTypes; + } + + public DuckDBLogicalType returnType() { + return returnType; + } + + public DuckDBColumnType returnColumnType() { + return returnColumnType; + } + + public DuckDBScalarFunction function() { + return function; + } + + public DuckDBLogicalType varArgType() { + return varArgType; + } + + public boolean isVolatile() { + return volatileFlag; + } + + public boolean hasSpecialHandling() { + return specialHandlingFlag; + } + + public boolean propagateNulls() { + return propagateNullsFlag; + } + + public boolean isScalar() { + return functionKind == DuckDBFunctionKind.SCALAR; + } + + static DuckDBRegisteredFunction of(String name, List parameterTypes, + List parameterColumnTypes, DuckDBLogicalType returnType, + DuckDBColumnType returnColumnType, DuckDBScalarFunction function, + DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, + boolean propagateNullsFlag) { + return new DuckDBRegisteredFunction( + name, DuckDBFunctionKind.SCALAR, Collections.unmodifiableList(new ArrayList<>(parameterTypes)), + Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), returnType, returnColumnType, function, + varArgType, volatileFlag, specialHandlingFlag, propagateNullsFlag); + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarContext.java b/src/main/java/org/duckdb/DuckDBScalarContext.java new file mode 100644 index 000000000..79689175b --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarContext.java @@ -0,0 +1,88 @@ +package org.duckdb; + +import java.sql.SQLException; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +public final class DuckDBScalarContext { + private final DuckDBDataChunkReader input; + private final 
DuckDBWritableVector output; + private final boolean propagateNulls; + + DuckDBScalarContext(DuckDBDataChunkReader input, DuckDBWritableVector output, boolean propagateNulls) { + if (input == null) { + throw new IllegalArgumentException("Input chunk cannot be null"); + } + if (output == null) { + throw new IllegalArgumentException("Output vector cannot be null"); + } + this.input = input; + this.output = output; + this.propagateNulls = propagateNulls; + } + + public long rowCount() { + return input.rowCount(); + } + + public int columnCount() { + return input.columnCount(); + } + + public DuckDBReadableVector input(int columnIndex) throws SQLException { + return input.vector(columnIndex); + } + + public DuckDBWritableVector output() { + return output; + } + + public boolean propagateNulls() { + return propagateNulls; + } + + DuckDBDataChunkReader inputChunk() { + return input; + } + + public Stream stream() { + LongStream rows = LongStream.range(0, rowCount()); + if (propagateNulls) { + rows = rows.filter(this::rowHasNoNullInputs); + } + return rows.mapToObj(this::row); + } + + public DuckDBScalarRow row(long rowIndex) { + checkRowIndex(rowIndex); + return new DuckDBScalarRow(this, rowIndex); + } + + DuckDBReadableVector inputUnchecked(int columnIndex) { + try { + return input(columnIndex); + } catch (SQLException exception) { + throw new IllegalStateException("Failed to access input column " + columnIndex, exception); + } + } + + private void checkRowIndex(long rowIndex) { + if (rowIndex < 0 || rowIndex >= rowCount()) { + throw new IndexOutOfBoundsException("Row index out of bounds: " + rowIndex); + } + } + + private boolean rowHasNoNullInputs(long rowIndex) { + for (int columnIndex = 0; columnIndex < columnCount(); columnIndex++) { + if (inputUnchecked(columnIndex).isNull(rowIndex)) { + try { + output.setNull(rowIndex); + } catch (SQLException exception) { + throw new IllegalStateException("Failed to write NULL to output row " + rowIndex, exception); + } + return 
false; + } + } + return true; + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunction.java b/src/main/java/org/duckdb/DuckDBScalarFunction.java index 9ef0e6fca..beef14162 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunction.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunction.java @@ -5,11 +5,11 @@ public interface DuckDBScalarFunction { /** * Processes a full input chunk and writes one output value per row directly into the DuckDB output vector. * - *

The input and output wrappers are valid only for the duration of the callback and must not be retained. + *

The context and all wrappers returned from it are valid only for the duration of the callback and must not + * be retained. * - * @param input input vectors for the current chunk - * @param out output vector for the current chunk + * @param ctx scalar function execution context for the current chunk * @throws Exception when function execution fails */ - void apply(DuckDBDataChunkReader input, DuckDBWritableVector out) throws Exception; + void apply(DuckDBScalarContext ctx) throws Exception; } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java new file mode 100644 index 000000000..dbefe0b50 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -0,0 +1,498 @@ +package org.duckdb; + +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; +import java.util.function.IntBinaryOperator; +import java.util.function.IntUnaryOperator; +import java.util.function.LongBinaryOperator; +import java.util.function.LongUnaryOperator; +import java.util.function.Supplier; + +final class DuckDBScalarFunctionAdapter { + private static final Map CODECS_BY_DUCKDB_TYPE = new LinkedHashMap<>(); + private static final Map, DuckDBColumnType> DUCKDB_TYPE_BY_JAVA_CLASS = new LinkedHashMap<>(); + private static final Map, Class> BOXED_CLASSES = new LinkedHashMap<>(); + private static final TypeCodec DATE_SQL_CODEC = + new TypeCodec(java.sql.Date.class, DuckDBReadableVector::getDate, + (vector, row, value) -> vector.setDate(row, 
(java.sql.Date) value)); + private static final TypeCodec TIMESTAMP_SQL_CODEC = + new TypeCodec(java.sql.Timestamp.class, DuckDBReadableVector::getTimestamp, + (vector, row, value) -> vector.setTimestamp(row, (java.sql.Timestamp) value)); + private static final TypeCodec TIMESTAMP_UTIL_DATE_CODEC = new TypeCodec( + Date.class, + (vector, row) + -> Date.from(vector.getLocalDateTime(row).toInstant(ZoneOffset.UTC)), + (vector, row, + value) -> vector.setTimestamp(row, LocalDateTime.ofInstant(((Date) value).toInstant(), ZoneOffset.UTC))); + + static { + register(DuckDBColumnType.BOOLEAN, Boolean.class, DuckDBReadableVector::getBoolean, + DuckDBWritableVector::setBoolean); + register(DuckDBColumnType.TINYINT, Byte.class, DuckDBReadableVector::getByte, DuckDBWritableVector::setByte); + register(DuckDBColumnType.UTINYINT, Short.class, DuckDBReadableVector::getUint8, + (out, row, value) -> out.setUint8(row, value)); + register(DuckDBColumnType.SMALLINT, Short.class, DuckDBReadableVector::getShort, + DuckDBWritableVector::setShort); + register(DuckDBColumnType.USMALLINT, Integer.class, DuckDBReadableVector::getUint16, + (out, row, value) -> out.setUint16(row, value)); + register(DuckDBColumnType.INTEGER, Integer.class, DuckDBReadableVector::getInt, DuckDBWritableVector::setInt); + register(DuckDBColumnType.UINTEGER, Long.class, DuckDBReadableVector::getUint32, + (out, row, value) -> out.setUint32(row, value)); + register(DuckDBColumnType.BIGINT, Long.class, DuckDBReadableVector::getLong, DuckDBWritableVector::setLong); + register(DuckDBColumnType.UBIGINT, BigInteger.class, DuckDBReadableVector::getUint64, + DuckDBWritableVector::setUint64); + register(DuckDBColumnType.FLOAT, Float.class, DuckDBReadableVector::getFloat, DuckDBWritableVector::setFloat); + register(DuckDBColumnType.DOUBLE, Double.class, DuckDBReadableVector::getDouble, + DuckDBWritableVector::setDouble); + register(DuckDBColumnType.DECIMAL, BigDecimal.class, DuckDBReadableVector::getBigDecimal, + 
DuckDBWritableVector::setBigDecimal); + register(DuckDBColumnType.VARCHAR, String.class, DuckDBReadableVector::getString, + DuckDBWritableVector::setString); + register(DuckDBColumnType.DATE, LocalDate.class, DuckDBReadableVector::getLocalDate, + DuckDBWritableVector::setDate); + register(DuckDBColumnType.TIMESTAMP_S, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_MS, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_NS, LocalDateTime.class, DuckDBReadableVector::getLocalDateTime, + DuckDBWritableVector::setTimestamp); + register(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, OffsetDateTime.class, + DuckDBReadableVector::getOffsetDateTime, DuckDBWritableVector::setOffsetDateTime); + + DUCKDB_TYPE_BY_JAVA_CLASS.put(boolean.class, DuckDBColumnType.BOOLEAN); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Boolean.class, DuckDBColumnType.BOOLEAN); + DUCKDB_TYPE_BY_JAVA_CLASS.put(byte.class, DuckDBColumnType.TINYINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Byte.class, DuckDBColumnType.TINYINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(short.class, DuckDBColumnType.SMALLINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Short.class, DuckDBColumnType.SMALLINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(int.class, DuckDBColumnType.INTEGER); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Integer.class, DuckDBColumnType.INTEGER); + DUCKDB_TYPE_BY_JAVA_CLASS.put(long.class, DuckDBColumnType.BIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Long.class, DuckDBColumnType.BIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(float.class, DuckDBColumnType.FLOAT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Float.class, DuckDBColumnType.FLOAT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(double.class, DuckDBColumnType.DOUBLE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Double.class, 
DuckDBColumnType.DOUBLE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(String.class, DuckDBColumnType.VARCHAR); + DUCKDB_TYPE_BY_JAVA_CLASS.put(BigDecimal.class, DuckDBColumnType.DECIMAL); + DUCKDB_TYPE_BY_JAVA_CLASS.put(BigInteger.class, DuckDBColumnType.UBIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDate.class, DuckDBColumnType.DATE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(java.sql.Date.class, DuckDBColumnType.DATE); + DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDateTime.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(java.sql.Timestamp.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(Date.class, DuckDBColumnType.TIMESTAMP); + DUCKDB_TYPE_BY_JAVA_CLASS.put(OffsetDateTime.class, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + + BOXED_CLASSES.put(boolean.class, Boolean.class); + BOXED_CLASSES.put(byte.class, Byte.class); + BOXED_CLASSES.put(short.class, Short.class); + BOXED_CLASSES.put(int.class, Integer.class); + BOXED_CLASSES.put(long.class, Long.class); + BOXED_CLASSES.put(float.class, Float.class); + BOXED_CLASSES.put(double.class, Double.class); + } + + static DuckDBScalarFunction unary(Function function, DuckDBColumnType parameterType, + Class parameterJavaType, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + @SuppressWarnings("unchecked") Function typedFunction = (Function) function; + TypeCodec inCodec = codecFor(parameterType, parameterJavaType); + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + boolean propagateNulls = ctx.propagateNulls(); + for (long row = 0; row < rowCount; row++) { + if (propagateNulls && in.isNull(row)) { + out.setNull(row); + continue; + } + Object argument = in.isNull(row) ? 
null : inCodec.read(in, row); + Object result = typedFunction.apply(argument); + outCodec.write(out, row, result); + } + }; + } + + static DuckDBScalarFunction binary(BiFunction function, DuckDBColumnType leftType, Class leftJavaType, + DuckDBColumnType rightType, Class rightJavaType, DuckDBColumnType returnType, + Class returnJavaType) throws SQLException { + @SuppressWarnings("unchecked") + BiFunction typedFunction = (BiFunction) function; + TypeCodec leftCodec = codecFor(leftType, leftJavaType); + TypeCodec rightCodec = codecFor(rightType, rightJavaType); + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + boolean propagateNulls = ctx.propagateNulls(); + for (long row = 0; row < rowCount; row++) { + boolean leftIsNull = left.isNull(row); + boolean rightIsNull = right.isNull(row); + if (propagateNulls && (leftIsNull || rightIsNull)) { + out.setNull(row); + continue; + } + Object leftValue = leftIsNull ? null : leftCodec.read(left, row); + Object rightValue = rightIsNull ? 
null : rightCodec.read(right, row); + Object result = typedFunction.apply(leftValue, rightValue); + outCodec.write(out, row, result); + } + }; + } + + static DuckDBScalarFunction intUnary(IntUnaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withIntFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(in.getInt(row))); + } + } + }; + } + + static DuckDBScalarFunction intBinary(IntBinaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withIntFunction requires propagateNulls(true)"); + } + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); + } + } + }; + } + + static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withDoubleFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, function.applyAsDouble(in.getDouble(row))); + } + } + }; + } + + static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withDoubleFunction requires propagateNulls(true)"); + } + DuckDBReadableVector 
left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); + } + } + }; + } + + static DuckDBScalarFunction longUnary(LongUnaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withLongFunction requires propagateNulls(true)"); + } + DuckDBReadableVector in = ctx.input(0); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(in.getLong(row))); + } + } + }; + } + + static DuckDBScalarFunction longBinary(LongBinaryOperator function) { + return ctx -> { + if (!ctx.propagateNulls()) { + throw new IllegalStateException("withLongFunction requires propagateNulls(true)"); + } + DuckDBReadableVector left = ctx.input(0); + DuckDBReadableVector right = ctx.input(1); + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); + } + } + }; + } + + static DuckDBScalarFunction nullary(Supplier function, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + @SuppressWarnings("unchecked") Supplier typedFunction = (Supplier) function; + TypeCodec outCodec = codecFor(returnType, returnJavaType); + return ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + Object result = typedFunction.get(); + outCodec.write(out, row, result); + } + }; + } + + static 
DuckDBScalarFunction variadic(Function function, DuckDBColumnType[] fixedTypes, + Class[] fixedJavaTypes, DuckDBColumnType varArgType, + Class varArgJavaType, DuckDBColumnType returnType, Class returnJavaType) + throws SQLException { + TypeCodec outCodec = codecFor(returnType, returnJavaType); + TypeCodec varArgCodec = codecFor(varArgType, varArgJavaType); + TypeCodec[] fixedCodecs = new TypeCodec[fixedTypes.length]; + for (int i = 0; i < fixedTypes.length; i++) { + fixedCodecs[i] = codecFor(fixedTypes[i], fixedJavaTypes[i]); + } + return ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + int vectorCount = ctx.columnCount(); + boolean propagateNulls = ctx.propagateNulls(); + DuckDBReadableVector[] vectors = new DuckDBReadableVector[vectorCount]; + TypeCodec[] codecs = new TypeCodec[vectorCount]; + for (int column = 0; column < vectorCount; column++) { + vectors[column] = ctx.input(column); + codecs[column] = column < fixedCodecs.length ? fixedCodecs[column] : varArgCodec; + } + Object[] args = new Object[vectorCount]; + for (long row = 0; row < rowCount; row++) { + boolean skipRow = false; + for (int column = 0; column < vectorCount; column++) { + DuckDBReadableVector vector = vectors[column]; + if (vector.isNull(row)) { + if (propagateNulls) { + out.setNull(row); + skipRow = true; + break; + } + args[column] = null; + } else { + args[column] = codecs[column].read(vector, row); + } + } + if (skipRow) { + continue; + } + Object result = function.apply(args); + outCodec.write(out, row, result); + } + }; + } + + static DuckDBColumnType mapJavaClassToDuckDBType(Class javaType) throws SQLException { + if (javaType == null) { + throw new SQLException("Java type cannot be null"); + } + Class normalizedClass = normalizeJavaClass(javaType); + DuckDBColumnType mappedType = DUCKDB_TYPE_BY_JAVA_CLASS.get(normalizedClass); + if (mappedType != null) { + return mappedType; + } + throw new SQLException("Unsupported Java type for scalar function 
mapping: " + javaType.getName()); + } + + static DuckDBColumnType mapLogicalTypeToDuckDBType(DuckDBLogicalType logicalType) throws SQLException { + if (logicalType == null) { + throw new SQLException("Logical type cannot be null"); + } + DuckDBBindings.CAPIType type = + DuckDBBindings.CAPIType.capiTypeFromTypeId(duckdb_get_type_id(logicalType.logicalTypeRef())); + switch (type) { + case DUCKDB_TYPE_BOOLEAN: + return DuckDBColumnType.BOOLEAN; + case DUCKDB_TYPE_TINYINT: + return DuckDBColumnType.TINYINT; + case DUCKDB_TYPE_UTINYINT: + return DuckDBColumnType.UTINYINT; + case DUCKDB_TYPE_SMALLINT: + return DuckDBColumnType.SMALLINT; + case DUCKDB_TYPE_USMALLINT: + return DuckDBColumnType.USMALLINT; + case DUCKDB_TYPE_INTEGER: + return DuckDBColumnType.INTEGER; + case DUCKDB_TYPE_UINTEGER: + return DuckDBColumnType.UINTEGER; + case DUCKDB_TYPE_BIGINT: + return DuckDBColumnType.BIGINT; + case DUCKDB_TYPE_UBIGINT: + return DuckDBColumnType.UBIGINT; + case DUCKDB_TYPE_FLOAT: + return DuckDBColumnType.FLOAT; + case DUCKDB_TYPE_DOUBLE: + return DuckDBColumnType.DOUBLE; + case DUCKDB_TYPE_DECIMAL: + return DuckDBColumnType.DECIMAL; + case DUCKDB_TYPE_VARCHAR: + return DuckDBColumnType.VARCHAR; + case DUCKDB_TYPE_DATE: + return DuckDBColumnType.DATE; + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBColumnType.TIMESTAMP_S; + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBColumnType.TIMESTAMP_MS; + case DUCKDB_TYPE_TIMESTAMP: + return DuckDBColumnType.TIMESTAMP; + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBColumnType.TIMESTAMP_NS; + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE; + default: + throw new SQLException("Unsupported logical type for Function/BiFunction mapping: " + type); + } + } + + private static TypeCodec codecFor(DuckDBColumnType type) throws SQLException { + TypeCodec codec = CODECS_BY_DUCKDB_TYPE.get(type); + if (codec != null) { + return codec; + } + throw new SQLException("Unsupported DuckDB type for Function/BiFunction 
mapping: " + type); + } + + private static TypeCodec codecFor(DuckDBColumnType type, Class declaredJavaType) throws SQLException { + if (declaredJavaType == null) { + return codecFor(type); + } + Class normalizedClass = normalizeJavaClass(declaredJavaType); + switch (type) { + case DATE: + if (normalizedClass == LocalDate.class) { + return codecFor(type); + } + if (normalizedClass == java.sql.Date.class) { + return DATE_SQL_CODEC; + } + break; + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + if (normalizedClass == LocalDateTime.class) { + return codecFor(type); + } + if (normalizedClass == java.sql.Timestamp.class) { + return TIMESTAMP_SQL_CODEC; + } + if (normalizedClass == Date.class) { + return TIMESTAMP_UTIL_DATE_CODEC; + } + break; + default: { + TypeCodec codec = codecFor(type); + if (codec.matches(normalizedClass)) { + return codec; + } + break; + } + } + throw new SQLException("Unsupported Java type " + normalizedClass.getName() + " for DuckDB type " + type + + " in functional scalar function mapping"); + } + + private static void register(DuckDBColumnType type, Class javaType, Reader reader, Writer writer) { + CODECS_BY_DUCKDB_TYPE.put(type, new TypeCodec(javaType, reader, writer)); + } + + private static Class normalizeJavaClass(Class javaType) { + Class boxedClass = BOXED_CLASSES.get(javaType); + return boxedClass != null ? 
boxedClass : javaType; + } + + private DuckDBScalarFunctionAdapter() { + } + + @FunctionalInterface + private interface Reader { + T read(DuckDBReadableVector vector, long row) throws SQLException; + } + + @FunctionalInterface + private interface Writer { + void write(DuckDBWritableVector vector, long row, T value) throws SQLException; + } + + private static final class TypeCodec { + private final Class javaType; + private final Reader reader; + private final Writer writer; + + private TypeCodec(Class javaType, Reader reader, Writer writer) { + this.javaType = javaType; + this.reader = reader; + this.writer = writer; + } + + boolean matches(Class declaredJavaType) { + return javaType == declaredJavaType; + } + + Object read(DuckDBReadableVector vector, long row) throws SQLException { + return reader.read(vector, row); + } + + void write(DuckDBWritableVector vector, long row, Object value) throws SQLException { + if (value == null) { + vector.setNull(row); + return; + } + if (!javaType.isInstance(value)) { + throw new ClassCastException("Expected value of type " + javaType.getName() + ", got " + + value.getClass().getName()); + } + @SuppressWarnings("unchecked") Writer typedWriter = (Writer) writer; + typedWriter.write(vector, row, value); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java new file mode 100644 index 000000000..446b7100d --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java @@ -0,0 +1,467 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.nio.ByteBuffer; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.function.BiFunction; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; +import 
java.util.function.Function; +import java.util.function.IntBinaryOperator; +import java.util.function.IntUnaryOperator; +import java.util.function.LongBinaryOperator; +import java.util.function.LongUnaryOperator; +import java.util.function.Supplier; + +public final class DuckDBScalarFunctionBuilder implements AutoCloseable { + private ByteBuffer scalarFunctionRef; + private String functionName; + private DuckDBLogicalType returnType; + private DuckDBColumnType returnColumnType; + private Class returnJavaType; + private DuckDBScalarFunction callback; + private DuckDBLogicalType varArgType; + private final List parameterTypes = new ArrayList<>(); + private final List parameterColumnTypes = new ArrayList<>(); + private final List> parameterJavaTypes = new ArrayList<>(); + private boolean volatileFlag; + private boolean specialHandlingFlag; + private boolean propagateNullsFlag = true; + private boolean callbackRequiresNullPropagation; + private boolean finalized; + + DuckDBScalarFunctionBuilder() throws SQLException { + this.scalarFunctionRef = duckdb_create_scalar_function(); + if (scalarFunctionRef == null) { + throw new SQLException("Failed to create scalar function"); + } + } + + public DuckDBScalarFunctionBuilder withName(String name) throws SQLException { + ensureNotFinalized(); + if (name == null || name.trim().isEmpty()) { + throw new SQLException("Function name cannot be null or empty"); + } + this.functionName = name; + duckdb_scalar_function_set_name(scalarFunctionRef, name.getBytes(UTF_8)); + return this; + } + + public DuckDBScalarFunctionBuilder withReturnType(DuckDBLogicalType returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + this.returnType = returnType; + this.returnColumnType = null; + this.returnJavaType = null; + duckdb_scalar_function_set_return_type(scalarFunctionRef, returnType.logicalTypeRef()); + return this; + } + + public 
DuckDBScalarFunctionBuilder withParameter(DuckDBLogicalType parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + parameterTypes.add(parameterType); + parameterColumnTypes.add(null); + parameterJavaTypes.add(null); + duckdb_scalar_function_add_parameter(scalarFunctionRef, parameterType.logicalTypeRef()); + return this; + } + + public DuckDBScalarFunctionBuilder withReturnType(Class returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + DuckDBColumnType mappedType = DuckDBScalarFunctionAdapter.mapJavaClassToDuckDBType(returnType); + return setMappedReturnType(mappedType, returnType); + } + + public DuckDBScalarFunctionBuilder withParameter(Class parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + DuckDBColumnType mappedType = DuckDBScalarFunctionAdapter.mapJavaClassToDuckDBType(parameterType); + return addMappedParameterType(mappedType, parameterType); + } + + public DuckDBScalarFunctionBuilder withReturnType(DuckDBColumnType returnType) throws SQLException { + ensureNotFinalized(); + if (returnType == null) { + throw new SQLException("Return type cannot be null"); + } + return setMappedReturnType(returnType, null); + } + + public DuckDBScalarFunctionBuilder withParameter(DuckDBColumnType parameterType) throws SQLException { + ensureNotFinalized(); + if (parameterType == null) { + throw new SQLException("Parameter type cannot be null"); + } + return addMappedParameterType(parameterType, null); + } + + public DuckDBScalarFunctionBuilder withVectorizedFunction(DuckDBScalarFunction function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + return setCallback(function, false); + 
} + + public DuckDBScalarFunctionBuilder propagateNulls(boolean propagateNulls) throws SQLException { + ensureNotFinalized(); + if (!propagateNulls && callbackRequiresNullPropagation) { + throw new SQLException("Primitive scalar callbacks require propagateNulls(true)"); + } + this.propagateNullsFlag = propagateNulls; + if (callback != null) { + duckdb_scalar_function_set_function(scalarFunctionRef, + new DuckDBScalarFunctionWrapper(callback, propagateNullsFlag)); + } + return this; + } + + public DuckDBScalarFunctionBuilder withIntFunction(IntUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withIntFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); + return setCallback(DuckDBScalarFunctionAdapter.intUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withIntFunction(IntBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withIntFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); + return setCallback(DuckDBScalarFunctionAdapter.intBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withDoubleFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); + return setCallback(DuckDBScalarFunctionAdapter.doubleUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw 
new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withDoubleFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); + return setCallback(DuckDBScalarFunctionAdapter.doubleBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withLongFunction(LongUnaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withLongFunction"); + ensureUnaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); + return setCallback(DuckDBScalarFunctionAdapter.longUnary(function), true); + } + + public DuckDBScalarFunctionBuilder withLongFunction(LongBinaryOperator function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + ensurePrimitiveCallbackCompatible("withLongFunction"); + ensureBinaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); + return setCallback(DuckDBScalarFunctionAdapter.longBinary(function), true); + } + + public DuckDBScalarFunctionBuilder withFunction(Function function) + throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (varArgType != null) { + throw new SQLException("Function callback does not support varargs; use withVarArgsFunction instead"); + } + if (parameterTypes.size() != 1) { + throw new SQLException("Function callback requires exactly 1 declared parameter"); + } + DuckDBColumnType parameterType = effectiveParameterType(0); + Class parameterJavaType = effectiveParameterJavaType(0); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.unary(function, parameterType, 
parameterJavaType, + resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withFunction(BiFunction function) + throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (varArgType != null) { + throw new SQLException("BiFunction callback does not support varargs; use withVarArgsFunction instead"); + } + if (parameterTypes.size() != 2) { + throw new SQLException("BiFunction callback requires exactly 2 declared parameters"); + } + DuckDBColumnType leftType = effectiveParameterType(0); + Class leftJavaType = effectiveParameterJavaType(0); + DuckDBColumnType rightType = effectiveParameterType(1); + Class rightJavaType = effectiveParameterJavaType(1); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.binary( + function, leftType, leftJavaType, rightType, rightJavaType, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withFunction(Supplier function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + } + if (!parameterTypes.isEmpty()) { + throw new SQLException("Supplier callback requires zero declared parameters"); + } + if (varArgType != null) { + throw new SQLException("Supplier callback does not support varargs"); + } + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction( + DuckDBScalarFunctionAdapter.nullary(function, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withVarArgsFunction(Function function) throws SQLException { + ensureNotFinalized(); + if (function == null) { + throw new SQLException("Scalar function callback cannot be null"); + 
} + if (varArgType == null) { + throw new SQLException("Varargs functional callback requires withVarArgs(...) declaration"); + } + DuckDBColumnType[] fixedTypes = effectiveFixedParameterTypes(); + Class[] fixedJavaTypes = effectiveFixedParameterJavaTypes(); + DuckDBColumnType varArgColumnType = DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(varArgType); + DuckDBColumnType resolvedReturnType = effectiveReturnType(); + Class resolvedReturnJavaType = effectiveReturnJavaType(); + return withVectorizedFunction(DuckDBScalarFunctionAdapter.variadic( + function, fixedTypes, fixedJavaTypes, varArgColumnType, null, resolvedReturnType, resolvedReturnJavaType)); + } + + public DuckDBScalarFunctionBuilder withVarArgs(DuckDBLogicalType varArgType) throws SQLException { + ensureNotFinalized(); + if (varArgType == null) { + throw new SQLException("Varargs type cannot be null"); + } + this.varArgType = varArgType; + duckdb_scalar_function_set_varargs(scalarFunctionRef, varArgType.logicalTypeRef()); + return this; + } + + public DuckDBScalarFunctionBuilder withVolatile() throws SQLException { + ensureNotFinalized(); + this.volatileFlag = true; + duckdb_scalar_function_set_volatile(scalarFunctionRef); + return this; + } + + public DuckDBScalarFunctionBuilder withSpecialHandling() throws SQLException { + ensureNotFinalized(); + this.specialHandlingFlag = true; + duckdb_scalar_function_set_special_handling(scalarFunctionRef); + return this; + } + + public DuckDBRegisteredFunction register(Connection connection) throws SQLException { + ensureNotFinalized(); + if (connection == null) { + throw new SQLException("Connection cannot be null"); + } + if (functionName == null) { + throw new SQLException("Function name must be defined"); + } + if (returnType == null && returnColumnType == null) { + throw new SQLException("Return type must be defined"); + } + if (callback == null) { + throw new SQLException("Scalar function callback must be defined"); + } + DuckDBConnection 
duckConnection = unwrapConnection(connection); + Lock connectionLock = duckConnection.connRefLock; + connectionLock.lock(); + try { + duckConnection.checkOpen(); + int status = duckdb_register_scalar_function(duckConnection.connRef, scalarFunctionRef); + if (status != 0) { + throw new SQLException("Failed to register scalar function '" + functionName + "'"); + } + DuckDBRegisteredFunction registeredFunction = DuckDBRegisteredFunction.of( + functionName, parameterTypes, parameterColumnTypes, returnType, returnColumnType, callback, varArgType, + volatileFlag, specialHandlingFlag, propagateNullsFlag); + DuckDBDriver.registerFunction(registeredFunction); + return registeredFunction; + } finally { + connectionLock.unlock(); + close(); + } + } + + @Override + public void close() { + if (scalarFunctionRef != null) { + duckdb_destroy_scalar_function(scalarFunctionRef); + scalarFunctionRef = null; + } + finalized = true; + } + + private void ensureNotFinalized() throws SQLException { + if (finalized || scalarFunctionRef == null) { + throw new SQLException("Scalar function builder is already finalized"); + } + } + + private DuckDBColumnType effectiveParameterType(int index) throws SQLException { + DuckDBColumnType parameterColumnType = parameterColumnTypes.get(index); + if (parameterColumnType != null) { + return parameterColumnType; + } + DuckDBLogicalType parameterLogicalType = parameterTypes.get(index); + return DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(parameterLogicalType); + } + + private DuckDBColumnType effectiveReturnType() throws SQLException { + if (returnColumnType != null) { + return returnColumnType; + } + if (returnType != null) { + return DuckDBScalarFunctionAdapter.mapLogicalTypeToDuckDBType(returnType); + } + throw new SQLException("Return type must be defined before functional callback"); + } + + private Class effectiveParameterJavaType(int index) { + return parameterJavaTypes.get(index); + } + + private Class effectiveReturnJavaType() { + 
return returnJavaType; + } + + private DuckDBColumnType[] effectiveFixedParameterTypes() throws SQLException { + DuckDBColumnType[] fixedTypes = new DuckDBColumnType[parameterTypes.size()]; + for (int i = 0; i < fixedTypes.length; i++) { + fixedTypes[i] = effectiveParameterType(i); + } + return fixedTypes; + } + + private Class[] effectiveFixedParameterJavaTypes() { + return parameterJavaTypes.toArray(new Class[ 0 ]); + } + + private DuckDBScalarFunctionBuilder setMappedReturnType(DuckDBColumnType mappedType, Class javaType) + throws SQLException { + this.returnType = null; + this.returnColumnType = mappedType; + this.returnJavaType = javaType; + try (DuckDBLogicalType logicalType = DuckDBLogicalType.of(mappedType)) { + duckdb_scalar_function_set_return_type(scalarFunctionRef, logicalType.logicalTypeRef()); + } + return this; + } + + private DuckDBScalarFunctionBuilder addMappedParameterType(DuckDBColumnType mappedType, Class javaType) + throws SQLException { + parameterTypes.add(null); + parameterColumnTypes.add(mappedType); + parameterJavaTypes.add(javaType); + try (DuckDBLogicalType logicalType = DuckDBLogicalType.of(mappedType)) { + duckdb_scalar_function_add_parameter(scalarFunctionRef, logicalType.logicalTypeRef()); + } + return this; + } + + private DuckDBScalarFunctionBuilder setCallback(DuckDBScalarFunction function, boolean requiresNullPropagation) + throws SQLException { + this.callback = function; + this.callbackRequiresNullPropagation = requiresNullPropagation; + duckdb_scalar_function_set_function(scalarFunctionRef, + new DuckDBScalarFunctionWrapper(function, propagateNullsFlag)); + return this; + } + + private void ensurePrimitiveCallbackCompatible(String callbackMethodName) throws SQLException { + if (!propagateNullsFlag) { + throw new SQLException(callbackMethodName + " requires propagateNulls(true)"); + } + if (varArgType != null) { + throw new SQLException(callbackMethodName + " does not support varargs; use withVarArgsFunction instead"); + } + } 
+ + private void ensureUnaryPrimitiveSignature(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + if (parameterTypes.size() != 1) { + throw new SQLException(callbackMethodName + " requires exactly 1 declared parameter"); + } + ensurePrimitiveParameterType(0, expectedType, callbackMethodName); + ensurePrimitiveReturnType(expectedType, callbackMethodName); + } + + private void ensureBinaryPrimitiveSignature(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + if (parameterTypes.size() != 2) { + throw new SQLException(callbackMethodName + " requires exactly 2 declared parameters"); + } + ensurePrimitiveParameterType(0, expectedType, callbackMethodName); + ensurePrimitiveParameterType(1, expectedType, callbackMethodName); + ensurePrimitiveReturnType(expectedType, callbackMethodName); + } + + private void ensurePrimitiveParameterType(int index, DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + DuckDBColumnType actualType = effectiveParameterType(index); + if (actualType != expectedType) { + throw new SQLException(callbackMethodName + " requires parameter " + index + " to be " + expectedType + + ", got " + actualType); + } + } + + private void ensurePrimitiveReturnType(DuckDBColumnType expectedType, String callbackMethodName) + throws SQLException { + DuckDBColumnType actualType = effectiveReturnType(); + if (actualType != expectedType) { + throw new SQLException(callbackMethodName + " requires return type " + expectedType + ", got " + + actualType); + } + } + + private static DuckDBConnection unwrapConnection(Connection connection) throws SQLException { + try { + return connection.unwrap(DuckDBConnection.class); + } catch (SQLException exception) { + throw new SQLException("Scalar function registration requires a DuckDB JDBC connection", exception); + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java 
b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java index 9390bd135..1bfc7f2d1 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -6,16 +6,19 @@ final class DuckDBScalarFunctionWrapper { private final DuckDBScalarFunction function; + private final boolean propagateNulls; - DuckDBScalarFunctionWrapper(DuckDBScalarFunction function) { + DuckDBScalarFunctionWrapper(DuckDBScalarFunction function, boolean propagateNulls) { this.function = function; + this.propagateNulls = propagateNulls; } public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, ByteBuffer outputVector) { try { DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk); - DuckDBWritableVector outputWriter = new DuckDBWritableVector(outputVector, inputReader.rowCount()); - function.apply(inputReader, outputWriter); + DuckDBWritableVector outputWriter = new DuckDBWritableVectorImpl(outputVector, inputReader.rowCount()); + DuckDBScalarContext context = new DuckDBScalarContext(inputReader, outputWriter, propagateNulls); + function.apply(context); } catch (Throwable throwable) { reportError(functionInfo, throwable); } diff --git a/src/main/java/org/duckdb/DuckDBScalarRow.java b/src/main/java/org/duckdb/DuckDBScalarRow.java new file mode 100644 index 000000000..cd0e10721 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBScalarRow.java @@ -0,0 +1,360 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; + +public final class DuckDBScalarRow { + private final DuckDBScalarContext context; + private final long rowIndex; + + DuckDBScalarRow(DuckDBScalarContext context, long rowIndex) { + this.context = context; + this.rowIndex = rowIndex; + } + + public long index() { + return rowIndex; + } + + public boolean 
isNull(int columnIndex) { + return input(columnIndex).isNull(rowIndex); + } + + public boolean getBoolean(int columnIndex) { + try { + return input(columnIndex).getBoolean(rowIndex); + } catch (SQLException exception) { + throw readFailure("BOOLEAN", columnIndex, exception); + } + } + + public byte getByte(int columnIndex) { + try { + return input(columnIndex).getByte(rowIndex); + } catch (SQLException exception) { + throw readFailure("TINYINT", columnIndex, exception); + } + } + + public short getShort(int columnIndex) { + try { + return input(columnIndex).getShort(rowIndex); + } catch (SQLException exception) { + throw readFailure("SMALLINT", columnIndex, exception); + } + } + + public short getUint8(int columnIndex) { + try { + return input(columnIndex).getUint8(rowIndex); + } catch (SQLException exception) { + throw readFailure("UTINYINT", columnIndex, exception); + } + } + + public int getUint16(int columnIndex) { + try { + return input(columnIndex).getUint16(rowIndex); + } catch (SQLException exception) { + throw readFailure("USMALLINT", columnIndex, exception); + } + } + + public int getInt(int columnIndex) { + try { + return input(columnIndex).getInt(rowIndex); + } catch (SQLException exception) { + throw readFailure("INTEGER", columnIndex, exception); + } + } + + public long getUint32(int columnIndex) { + try { + return input(columnIndex).getUint32(rowIndex); + } catch (SQLException exception) { + throw readFailure("UINTEGER", columnIndex, exception); + } + } + + public long getLong(int columnIndex) { + try { + return input(columnIndex).getLong(rowIndex); + } catch (SQLException exception) { + throw readFailure("BIGINT", columnIndex, exception); + } + } + + public BigInteger getUint64(int columnIndex) { + try { + return input(columnIndex).getUint64(rowIndex); + } catch (SQLException exception) { + throw readFailure("UBIGINT", columnIndex, exception); + } + } + + public float getFloat(int columnIndex) { + try { + return 
input(columnIndex).getFloat(rowIndex); + } catch (SQLException exception) { + throw readFailure("FLOAT", columnIndex, exception); + } + } + + public double getDouble(int columnIndex) { + try { + return input(columnIndex).getDouble(rowIndex); + } catch (SQLException exception) { + throw readFailure("DOUBLE", columnIndex, exception); + } + } + + public LocalDate getLocalDate(int columnIndex) { + try { + return input(columnIndex).getLocalDate(rowIndex); + } catch (SQLException exception) { + throw readFailure("DATE", columnIndex, exception); + } + } + + public java.sql.Date getDate(int columnIndex) { + try { + return input(columnIndex).getDate(rowIndex); + } catch (SQLException exception) { + throw readFailure("DATE", columnIndex, exception); + } + } + + public LocalDateTime getLocalDateTime(int columnIndex) { + try { + return input(columnIndex).getLocalDateTime(rowIndex); + } catch (SQLException exception) { + throw readFailure("TIMESTAMP", columnIndex, exception); + } + } + + public Timestamp getTimestamp(int columnIndex) { + try { + return input(columnIndex).getTimestamp(rowIndex); + } catch (SQLException exception) { + throw readFailure("TIMESTAMP", columnIndex, exception); + } + } + + public OffsetDateTime getOffsetDateTime(int columnIndex) { + try { + return input(columnIndex).getOffsetDateTime(rowIndex); + } catch (SQLException exception) { + throw readFailure("TIMESTAMP WITH TIME ZONE", columnIndex, exception); + } + } + + public BigDecimal getBigDecimal(int columnIndex) { + try { + return input(columnIndex).getBigDecimal(rowIndex); + } catch (SQLException exception) { + throw readFailure("DECIMAL", columnIndex, exception); + } + } + + public String getString(int columnIndex) { + try { + return input(columnIndex).getString(rowIndex); + } catch (SQLException exception) { + throw readFailure("VARCHAR", columnIndex, exception); + } + } + + public void setNull() { + try { + context.output().setNull(rowIndex); + } catch (SQLException exception) { + throw 
writeFailure("NULL", exception); + } + } + + public void setBoolean(boolean value) { + try { + context.output().setBoolean(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("BOOLEAN", exception); + } + } + + public void setByte(byte value) { + try { + context.output().setByte(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TINYINT", exception); + } + } + + public void setShort(short value) { + try { + context.output().setShort(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("SMALLINT", exception); + } + } + + public void setUint8(int value) { + try { + context.output().setUint8(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("UTINYINT", exception); + } + } + + public void setUint16(int value) { + try { + context.output().setUint16(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("USMALLINT", exception); + } + } + + public void setInt(int value) { + try { + context.output().setInt(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("INTEGER", exception); + } + } + + public void setUint32(long value) { + try { + context.output().setUint32(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("UINTEGER", exception); + } + } + + public void setLong(long value) { + try { + context.output().setLong(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("BIGINT", exception); + } + } + + public void setUint64(BigInteger value) { + try { + context.output().setUint64(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("UBIGINT", exception); + } + } + + public void setFloat(float value) { + try { + context.output().setFloat(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("FLOAT", exception); + } + } + + public void setDouble(double value) { + try { + context.output().setDouble(rowIndex, value); + } catch (SQLException 
exception) { + throw writeFailure("DOUBLE", exception); + } + } + + public void setDate(LocalDate value) { + try { + context.output().setDate(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setDate(java.sql.Date value) { + try { + context.output().setDate(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setDate(java.util.Date value) { + try { + context.output().setDate(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("DATE", exception); + } + } + + public void setTimestamp(LocalDateTime value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(Timestamp value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(java.util.Date value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setTimestamp(LocalDate value) { + try { + context.output().setTimestamp(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TIMESTAMP", exception); + } + } + + public void setOffsetDateTime(OffsetDateTime value) { + try { + context.output().setOffsetDateTime(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("TIMESTAMP WITH TIME ZONE", exception); + } + } + + public void setBigDecimal(BigDecimal value) { + try { + context.output().setBigDecimal(rowIndex, value); + } catch (SQLException exception) { + throw writeFailure("DECIMAL", exception); + } + } + + public void setString(String value) { + try { + context.output().setString(rowIndex, value); + } catch (SQLException exception) { + throw 
writeFailure("VARCHAR", exception); + } + } + + private DuckDBReadableVector input(int columnIndex) { + return context.inputUnchecked(columnIndex); + } + + private IllegalStateException readFailure(String type, int columnIndex, SQLException exception) { + return new IllegalStateException( + "Failed to read " + type + " from input column " + columnIndex + " at row " + rowIndex, exception); + } + + private IllegalStateException writeFailure(String type, SQLException exception) { + return new IllegalStateException("Failed to write " + type + " to output row " + rowIndex, exception); + } +} diff --git a/src/main/java/org/duckdb/DuckDBWritableVector.java b/src/main/java/org/duckdb/DuckDBWritableVector.java index 606676a8e..d1532a976 100644 --- a/src/main/java/org/duckdb/DuckDBWritableVector.java +++ b/src/main/java/org/duckdb/DuckDBWritableVector.java @@ -1,406 +1,103 @@ package org.duckdb; -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.duckdb.DuckDBBindings.*; - import java.math.BigDecimal; import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.sql.SQLException; import java.sql.Timestamp; -import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; -import java.time.ZoneId; -import java.time.ZoneOffset; - -public final class DuckDBWritableVector { - private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); - private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); - - private final ByteBuffer vectorRef; - private final int rowCount; - private final DuckDBVectorTypeInfo typeInfo; - private final ByteBuffer data; - private ByteBuffer validity; - - DuckDBWritableVector(ByteBuffer vectorRef, int rowCount) throws SQLException { - if (vectorRef == null) { - throw new SQLException("Invalid vector reference"); - } - this.vectorRef = vectorRef; - this.rowCount = rowCount; - this.typeInfo = 
DuckDBVectorTypeInfo.fromVector(vectorRef); - this.data = duckdb_vector_get_data(vectorRef, (long) rowCount * typeInfo.widthBytes); - this.validity = duckdb_vector_get_validity(vectorRef, rowCount); - } - - public DuckDBColumnType getType() { - return typeInfo.columnType; - } - - public int rowCount() { - return rowCount; - } - - public void setNull(int row) throws SQLException { - checkRowIndex(row); - ensureValidity(); - duckdb_validity_set_row_validity(validity, row, false); - } - - public void setBoolean(int row, boolean value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.BOOLEAN); - data.put(row, value ? (byte) 1 : (byte) 0); - markValid(row); - } - - public void setByte(int row, byte value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.TINYINT); - data.put(row, value); - markValid(row); - } - - public void setShort(int row, short value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.SMALLINT); - data.order(NATIVE_ORDER).putShort(row * Short.BYTES, value); - markValid(row); - } - - public void setUint8(int row, int value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.UTINYINT); - checkUnsignedRange("UTINYINT", value, 0xFFL); - data.put(row, (byte) value); - markValid(row); - } - - public void setUint16(int row, int value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.USMALLINT); - checkUnsignedRange("USMALLINT", value, 0xFFFFL); - data.order(NATIVE_ORDER).putShort(row * Short.BYTES, (short) value); - markValid(row); - } - - public void setInt(int row, int value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.INTEGER); - data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, value); - markValid(row); - } - - public void setUint32(int row, long value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.UINTEGER); - checkUnsignedRange("UINTEGER", value, 0xFFFFFFFFL); - 
data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, (int) value); - markValid(row); - } - - public void setLong(int row, long value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.BIGINT); - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, value); - markValid(row); - } - - public void setUint64(int row, BigInteger value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.UBIGINT); - if (value == null) { - setNull(row); - return; - } - if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { - throw new SQLException("Value out of range for UBIGINT: " + value); - } - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, value.longValue()); - markValid(row); - } - - public void setFloat(int row, float value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.FLOAT); - data.order(NATIVE_ORDER).putFloat(row * Float.BYTES, value); - markValid(row); - } - - public void setDouble(int row, double value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DOUBLE); - data.order(NATIVE_ORDER).putDouble(row * Double.BYTES, value); - markValid(row); - } - - public void setDate(int row, LocalDate value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DATE); - if (value == null) { - setNull(row); - return; - } - long days = value.toEpochDay(); - if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { - throw new SQLException("Value out of range for DATE: " + value); - } - data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, (int) days); - markValid(row); - } - - public void setDate(int row, java.sql.Date value) throws SQLException { - setDate(row, value == null ? 
null : value.toLocalDate()); - } - - public void setDate(int row, java.util.Date value) throws SQLException { - if (value == null) { - setNull(row); - return; - } - if (value instanceof java.sql.Date) { - setDate(row, (java.sql.Date) value); - return; - } - LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate(); - setDate(row, localDate); - } - - public void setTimestamp(int row, LocalDateTime value) throws SQLException { - checkRowIndex(row); - requireTimestampType(false); - if (value == null) { - setNull(row); - return; - } - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeLocalDateTime(value)); - markValid(row); - } - - public void setTimestamp(int row, Timestamp value) throws SQLException { - if (value == null) { - setNull(row); - return; - } - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - checkRowIndex(row); - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeInstant(value.toInstant())); - markValid(row); - return; - } - setTimestamp(row, value.toLocalDateTime()); - } - - public void setTimestamp(int row, java.util.Date value) throws SQLException { - checkRowIndex(row); - requireTimestampType(false); - if (value == null) { - setNull(row); - return; - } - if (value instanceof Timestamp) { - setTimestamp(row, (Timestamp) value); - return; - } - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeJavaUtilDate(value)); - markValid(row); - } - - public void setTimestamp(int row, LocalDate value) throws SQLException { - if (value == null) { - setNull(row); - return; - } - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - checkRowIndex(row); - Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, encodeInstant(instant)); - markValid(row); - return; - } - setTimestamp(row, value.atStartOfDay()); - } - - public void setOffsetDateTime(int row, OffsetDateTime value) throws SQLException 
{ - checkRowIndex(row); - requireTimestampType(true); - if (value == null) { - setNull(row); - return; - } - data.order(NATIVE_ORDER) - .putLong(row * Long.BYTES, DuckDBTimestamp.localDateTime2Micros( - value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); - markValid(row); - } - - public void setBigDecimal(int row, BigDecimal value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.DECIMAL); - if (value == null) { - setNull(row); - return; - } - BigDecimal scaled; - try { - scaled = value.setScale(typeInfo.decimalMeta.scale); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - if (scaled.precision() > typeInfo.decimalMeta.width) { - throw decimalOutOfRange(value); - } - switch (typeInfo.storageType) { - case DUCKDB_TYPE_SMALLINT: - try { - data.order(NATIVE_ORDER).putShort(row * Short.BYTES, scaled.unscaledValue().shortValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_INTEGER: - try { - data.order(NATIVE_ORDER).putInt(row * Integer.BYTES, scaled.unscaledValue().intValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_BIGINT: - try { - data.order(NATIVE_ORDER).putLong(row * Long.BYTES, scaled.unscaledValue().longValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_HUGEINT: { - BigInteger unscaled = scaled.unscaledValue(); - ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); - slice.position(row * typeInfo.widthBytes); - slice.putLong(unscaled.longValue()); - slice.putLong(unscaled.shiftRight(Long.SIZE).longValue()); - break; - } - default: - throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); - } - markValid(row); - } - - public void setString(int row, String value) throws SQLException { - checkRowIndex(row); - requireType(DuckDBColumnType.VARCHAR); - if (value == 
null) { - setNull(row); - return; - } - duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8)); - markValid(row); - } - - ByteBuffer vectorRef() { - return vectorRef; - } - - private void ensureValidity() throws SQLException { - if (validity != null) { - return; - } - duckdb_vector_ensure_validity_writable(vectorRef); - validity = duckdb_vector_get_validity(vectorRef, rowCount); - if (validity == null) { - throw new SQLException("Cannot initialize vector validity"); - } - } - - private void markValid(int row) { - if (validity == null) { - return; - } - duckdb_validity_set_row_validity(validity, row, true); - } - - private void requireType(DuckDBColumnType expected) throws SQLException { - if (typeInfo.columnType != expected) { - throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType); - } - } - - private void checkRowIndex(int row) { - if (row < 0 || row >= rowCount) { - throw new IndexOutOfBoundsException("Row index out of bounds: " + row); - } - } - - private void requireTimestampType(boolean requireTimezone) throws SQLException { - if (requireTimezone) { - if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - throw new SQLException("Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType); - } - return; - } - switch (typeInfo.columnType) { - case TIMESTAMP: - case TIMESTAMP_S: - case TIMESTAMP_MS: - case TIMESTAMP_NS: - case TIMESTAMP_WITH_TIME_ZONE: - return; - default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } - - private long encodeLocalDateTime(LocalDateTime value) throws SQLException { - Instant instant; - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - instant = value.atZone(ZoneId.systemDefault()).toInstant(); - } else { - instant = value.toInstant(ZoneOffset.UTC); - } - return encodeInstant(instant); - } - - private long encodeJavaUtilDate(java.util.Date value) throws 
SQLException { - return encodeInstant(Instant.ofEpochMilli(value.getTime())); - } - - private long encodeInstant(Instant instant) throws SQLException { - long epochSeconds = instant.getEpochSecond(); - int nano = instant.getNano(); - switch (typeInfo.capiType) { - case DUCKDB_TYPE_TIMESTAMP_S: - return epochSeconds; - case DUCKDB_TYPE_TIMESTAMP_MS: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L); - case DUCKDB_TYPE_TIMESTAMP: - case DUCKDB_TYPE_TIMESTAMP_TZ: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L); - case DUCKDB_TYPE_TIMESTAMP_NS: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); - default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } - - private static void checkUnsignedRange(String typeName, long value, long maxValue) throws SQLException { - if (value < 0 || value > maxValue) { - throw new SQLException("Value out of range for " + typeName + ": " + value); - } - } - - private SQLException decimalOutOfRange(BigDecimal value) { - return new SQLException("Value out of range for " + decimalTypeName() + ": " + value); - } - - private SQLException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { - return new SQLException("Value out of range for " + decimalTypeName() + ": " + value, cause); - } - - private String decimalTypeName() { - return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")"; - } + +public interface DuckDBWritableVector { + DuckDBColumnType getType(); + + long rowCount(); + + void addNull() throws SQLException; + + void setNull(long row) throws SQLException; + + void addBoolean(boolean value) throws SQLException; + + void setBoolean(long row, boolean value) throws SQLException; + + void addByte(byte value) throws SQLException; + + void setByte(long row, byte value) throws SQLException; + + void addShort(short value) throws SQLException; + + void 
setShort(long row, short value) throws SQLException; + + void addUint8(int value) throws SQLException; + + void setUint8(long row, int value) throws SQLException; + + void addUint16(int value) throws SQLException; + + void setUint16(long row, int value) throws SQLException; + + void addInt(int value) throws SQLException; + + void setInt(long row, int value) throws SQLException; + + void addUint32(long value) throws SQLException; + + void setUint32(long row, long value) throws SQLException; + + void addLong(long value) throws SQLException; + + void setLong(long row, long value) throws SQLException; + + void addUint64(BigInteger value) throws SQLException; + + void setUint64(long row, BigInteger value) throws SQLException; + + void addFloat(float value) throws SQLException; + + void setFloat(long row, float value) throws SQLException; + + void addDouble(double value) throws SQLException; + + void setDouble(long row, double value) throws SQLException; + + void addDate(LocalDate value) throws SQLException; + + void setDate(long row, LocalDate value) throws SQLException; + + void addDate(java.sql.Date value) throws SQLException; + + void setDate(long row, java.sql.Date value) throws SQLException; + + void addDate(java.util.Date value) throws SQLException; + + void setDate(long row, java.util.Date value) throws SQLException; + + void addTimestamp(LocalDateTime value) throws SQLException; + + void setTimestamp(long row, LocalDateTime value) throws SQLException; + + void addTimestamp(Timestamp value) throws SQLException; + + void setTimestamp(long row, Timestamp value) throws SQLException; + + void addTimestamp(java.util.Date value) throws SQLException; + + void setTimestamp(long row, java.util.Date value) throws SQLException; + + void addTimestamp(LocalDate value) throws SQLException; + + void setTimestamp(long row, LocalDate value) throws SQLException; + + void addOffsetDateTime(OffsetDateTime value) throws SQLException; + + void setOffsetDateTime(long row, 
OffsetDateTime value) throws SQLException; + + void addBigDecimal(BigDecimal value) throws SQLException; + + void setBigDecimal(long row, BigDecimal value) throws SQLException; + + void addString(String value) throws SQLException; + + void setString(long row, String value) throws SQLException; } diff --git a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java new file mode 100644 index 000000000..4b5a68e47 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java @@ -0,0 +1,707 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.LongBuffer; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; + +final class DuckDBWritableVectorImpl implements DuckDBWritableVector { + private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private ByteBuffer validity; + private long appendIndex; + + DuckDBWritableVectorImpl(ByteBuffer vectorRef, long rowCount) throws SQLException { + if (vectorRef == null) { + throw new SQLException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)); + this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + } + + @Override + 
public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + @Override + public long rowCount() { + return rowCount; + } + + @Override + public void addNull() throws SQLException { + setNull(nextAppendRow()); + } + + @Override + public void setNull(long row) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + ensureValidity(); + setRowValidity(row, false); + advanceAppendIndex(row); + } + + @Override + public void addBoolean(boolean value) throws SQLException { + setBoolean(nextAppendRow(), value); + } + + @Override + public void setBoolean(long row, boolean value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BOOLEAN); + if (typeError != null) { + throw new SQLException(typeError); + } + data.put(checkedRowIndex(row), value ? 
(byte) 1 : (byte) 0); + markValid(row); + } + + @Override + public void addByte(byte value) throws SQLException { + setByte(nextAppendRow(), value); + } + + @Override + public void setByte(long row, byte value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.TINYINT); + if (typeError != null) { + throw new SQLException(typeError); + } + data.put(checkedRowIndex(row), value); + markValid(row); + } + + @Override + public void addShort(short value) throws SQLException { + setShort(nextAppendRow(), value); + } + + @Override + public void setShort(long row, short value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.SMALLINT); + if (typeError != null) { + throw new SQLException(typeError); + } + data.order(NATIVE_ORDER).putShort(checkedByteOffset(row, Short.BYTES), value); + markValid(row); + } + + @Override + public void addUint8(int value) throws SQLException { + setUint8(nextAppendRow(), value); + } + + @Override + public void setUint8(long row, int value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UTINYINT); + if (typeError != null) { + throw new SQLException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UTINYINT", value, 0xFFL); + if (rangeError != null) { + throw new SQLException(rangeError); + } + data.put(checkedRowIndex(row), (byte) value); + markValid(row); + } + + @Override + public void addUint16(int value) throws SQLException { + setUint16(nextAppendRow(), value); + } + + @Override + public void setUint16(long row, int value) throws SQLException { + String rowError = 
rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.USMALLINT); + if (typeError != null) { + throw new SQLException(typeError); + } + String rangeError = unsignedRangeErrorMessage("USMALLINT", value, 0xFFFFL); + if (rangeError != null) { + throw new SQLException(rangeError); + } + data.order(NATIVE_ORDER).putShort(checkedByteOffset(row, Short.BYTES), (short) value); + markValid(row); + } + + @Override + public void addInt(int value) throws SQLException { + setInt(nextAppendRow(), value); + } + + @Override + public void setInt(long row, int value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.INTEGER); + if (typeError != null) { + throw new SQLException(typeError); + } + data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), value); + markValid(row); + } + + @Override + public void addUint32(long value) throws SQLException { + setUint32(nextAppendRow(), value); + } + + @Override + public void setUint32(long row, long value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UINTEGER); + if (typeError != null) { + throw new SQLException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UINTEGER", value, 0xFFFFFFFFL); + if (rangeError != null) { + throw new SQLException(rangeError); + } + data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), (int) value); + markValid(row); + } + + @Override + public void addLong(long value) throws SQLException { + setLong(nextAppendRow(), value); + } + + @Override + public void setLong(long row, long value) throws SQLException { + String rowError = 
rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BIGINT); + if (typeError != null) { + throw new SQLException(typeError); + } + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), value); + markValid(row); + } + + @Override + public void addUint64(BigInteger value) throws SQLException { + setUint64(nextAppendRow(), value); + } + + @Override + public void setUint64(long row, BigInteger value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UBIGINT); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { + throw new SQLException("Value out of range for UBIGINT: " + value); + } + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); + markValid(row); + } + + @Override + public void addFloat(float value) throws SQLException { + setFloat(nextAppendRow(), value); + } + + @Override + public void setFloat(long row, float value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.FLOAT); + if (typeError != null) { + throw new SQLException(typeError); + } + data.order(NATIVE_ORDER).putFloat(checkedByteOffset(row, Float.BYTES), value); + markValid(row); + } + + @Override + public void addDouble(double value) throws SQLException { + setDouble(nextAppendRow(), value); + } + + @Override + public void setDouble(long row, double value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + 
String typeError = typeMismatchMessage(DuckDBColumnType.DOUBLE); + if (typeError != null) { + throw new SQLException(typeError); + } + data.order(NATIVE_ORDER).putDouble(checkedByteOffset(row, Double.BYTES), value); + markValid(row); + } + + @Override + public void addDate(LocalDate value) throws SQLException { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, LocalDate value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DATE); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + long days = value.toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new SQLException("Value out of range for DATE: " + value); + } + data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), (int) days); + markValid(row); + } + + @Override + public void addDate(java.sql.Date value) throws SQLException { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, java.sql.Date value) throws SQLException { + setDate(row, value == null ? 
null : value.toLocalDate()); + } + + @Override + public void addDate(java.util.Date value) throws SQLException { + setDate(nextAppendRow(), value); + } + + @Override + public void setDate(long row, java.util.Date value) throws SQLException { + if (value == null) { + setNull(row); + return; + } + if (value instanceof java.sql.Date) { + setDate(row, (java.sql.Date) value); + return; + } + LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate(); + setDate(row, localDate); + } + + @Override + public void addTimestamp(LocalDateTime value) throws SQLException { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, LocalDateTime value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); + markValid(row); + } + + @Override + public void addTimestamp(Timestamp value) throws SQLException { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, Timestamp value) throws SQLException { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(value.toInstant())); + markValid(row); + return; + } + setTimestamp(row, value.toLocalDateTime()); + } + + @Override + public void addTimestamp(java.util.Date value) throws SQLException { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long 
row, java.util.Date value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value instanceof Timestamp) { + setTimestamp(row, (Timestamp) value); + return; + } + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); + markValid(row); + } + + @Override + public void addTimestamp(LocalDate value) throws SQLException { + setTimestamp(nextAppendRow(), value); + } + + @Override + public void setTimestamp(long row, LocalDate value) throws SQLException { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); + data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); + markValid(row); + return; + } + setTimestamp(row, value.atStartOfDay()); + } + + @Override + public void addOffsetDateTime(OffsetDateTime value) throws SQLException { + setOffsetDateTime(nextAppendRow(), value); + } + + @Override + public void setOffsetDateTime(long row, OffsetDateTime value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(true); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.order(NATIVE_ORDER) + .putLong( + checkedByteOffset(row, Long.BYTES), + 
DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); + markValid(row); + } + + @Override + public void addBigDecimal(BigDecimal value) throws SQLException { + setBigDecimal(nextAppendRow(), value); + } + + @Override + public void setBigDecimal(long row, BigDecimal value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DECIMAL); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + BigDecimal scaled; + try { + scaled = value.setScale(typeInfo.decimalMeta.scale); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + if (scaled.precision() > typeInfo.decimalMeta.width) { + throw decimalOutOfRange(value); + } + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + try { + data.order(NATIVE_ORDER) + .putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_INTEGER: + try { + data.order(NATIVE_ORDER) + .putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_BIGINT: + try { + data.order(NATIVE_ORDER) + .putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_HUGEINT: { + BigInteger unscaled = scaled.unscaledValue(); + ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); + slice.position(checkedByteOffset(row, typeInfo.widthBytes)); + slice.putLong(unscaled.longValue()); + slice.putLong(unscaled.shiftRight(Long.SIZE).longValue()); + break; + } + default: + throw 
new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + markValid(row); + } + + @Override + public void addString(String value) throws SQLException { + setString(nextAppendRow(), value); + } + + @Override + public void setString(long row, String value) throws SQLException { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.VARCHAR); + if (typeError != null) { + throw new SQLException(typeError); + } + if (value == null) { + setNull(row); + return; + } + duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8)); + markValid(row); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void ensureValidity() throws SQLException { + if (validity != null) { + return; + } + duckdb_vector_ensure_validity_writable(vectorRef); + validity = duckdb_vector_get_validity(vectorRef, rowCount); + if (validity == null) { + throw new SQLException("Cannot initialize vector validity"); + } + } + + private void markValid(long row) { + if (validity == null) { + advanceAppendIndex(row); + return; + } + setRowValidity(row, true); + advanceAppendIndex(row); + } + + private void setRowValidity(long row, boolean valid) { + LongBuffer entries = validity.asLongBuffer(); + int entryIndex = Math.toIntExact(row / Long.SIZE); + long bitIndex = row % Long.SIZE; + long mask = 1L << bitIndex; + long entry = entries.get(entryIndex); + if (valid) { + entry |= mask; + } else { + entry &= ~mask; + } + entries.put(entryIndex, entry); + } + + private String typeMismatchMessage(DuckDBColumnType expected) { + if (typeInfo.columnType != expected) { + return "Expected vector type " + expected + ", found " + typeInfo.columnType; + } + return null; + } + + private String rowIndexErrorMessage(long row) { + if (row < 0 || row >= rowCount) { + return "Row index out of bounds: " + row; + } + return null; + } + + private 
String timestampTypeMismatchMessage(boolean requireTimezone) { + if (requireTimezone) { + if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + return "Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType; + } + return null; + } + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return null; + default: + return "Expected vector type TIMESTAMP*, found " + typeInfo.columnType; + } + } + + private long encodeLocalDateTime(LocalDateTime value) throws SQLException { + Instant instant; + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + instant = value.atZone(ZoneId.systemDefault()).toInstant(); + } else { + instant = value.toInstant(ZoneOffset.UTC); + } + return encodeInstant(instant); + } + + private long encodeJavaUtilDate(java.util.Date value) throws SQLException { + return encodeInstant(Instant.ofEpochMilli(value.getTime())); + } + + private long encodeInstant(Instant instant) throws SQLException { + long epochSeconds = instant.getEpochSecond(); + int nano = instant.getNano(); + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return epochSeconds; + case DUCKDB_TYPE_TIMESTAMP_MS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L); + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_TZ: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L); + case DUCKDB_TYPE_TIMESTAMP_NS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); + default: + throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private static String unsignedRangeErrorMessage(String typeName, long value, long maxValue) { + if (value < 0 || value > maxValue) { + return "Value out of range for " + typeName + ": " + value; + } + return null; + } + + private SQLException 
decimalOutOfRange(BigDecimal value) { + return new SQLException("Value out of range for " + decimalTypeName() + ": " + value); + } + + private SQLException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { + SQLException exception = decimalOutOfRange(value); + exception.initCause(cause); + return exception; + } + + private String decimalTypeName() { + return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")"; + } + + private int checkedRowIndex(long row) { + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } + + private void advanceAppendIndex(long row) { + appendIndex = Math.max(appendIndex, Math.addExact(row, 1)); + } + + private long nextAppendRow() { + if (appendIndex >= rowCount) { + throw new IndexOutOfBoundsException("Append index out of bounds: " + appendIndex); + } + return appendIndex; + } +} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index ca8b2431f..a947304e9 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -22,6 +22,71 @@ public static void test_bindings_vector_size() throws Exception { assertTrue(size > 0); } + public static void test_bindings_vector_row_index_stream() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer inputVec = duckdb_create_vector(lt); + ByteBuffer outputVec = duckdb_create_vector(lt); + + DuckDBWritableVector input = new DuckDBWritableVectorImpl(inputVec, 3); + input.setInt(0, 1); + input.setInt(1, 41); + input.setInt(2, -5); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(inputVec, 3); + DuckDBWritableVector output = new DuckDBWritableVectorImpl(outputVec, 3); + readable.rowIndexStream().forEachOrdered(row -> { + try { + output.setInt(row, readable.getInt(row) + 1); + } catch 
(SQLException exception) { + throw new RuntimeException(exception); + } + }); + + DuckDBReadableVector result = new DuckDBReadableVectorImpl(outputVec, 3); + assertEquals(result.getInt(0), 2); + assertEquals(result.getInt(1), 42); + assertEquals(result.getInt(2), -4); + + duckdb_destroy_vector(outputVec); + duckdb_destroy_vector(inputVec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_writable_vector_append_after_indexed_write() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 3); + writable.setNull(0); + writable.setInt(1, 41); + writable.addInt(42); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 3); + assertTrue(readable.isNull(0)); + assertEquals(readable.getInt(1), 41); + assertEquals(readable.getInt(2), 42); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + public static void test_bindings_writable_vector_failed_append_does_not_advance() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 2); + assertThrows(() -> { writable.addString("boom"); }, SQLException.class); + writable.addInt(7); + writable.addInt(8); + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 2); + assertEquals(readable.getInt(0), 7); + assertEquals(readable.getInt(1), 8); + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_logical_type() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); assertNotNull(lt); @@ -130,19 +195,19 @@ public static void test_bindings_vector_strings() throws Exception { duckdb_destroy_logical_type(lt); } - public static void test_bindings_varchar_string_bytes_null_row() 
throws Exception { + public static void test_bindings_vector_get_string() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); long rowCount = duckdb_vector_size(); - ByteBuffer data = duckdb_vector_get_data(vec, rowCount * STRING_T_SIZE_BYTES); - duckdb_vector_ensure_validity_writable(vec); - ByteBuffer validity = duckdb_vector_get_validity(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + writable.setNull(0); + writable.setString(1, "duckdb"); - duckdb_validity_set_row_validity(validity, 0L, false); - assertNull(duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, 0L)); - assertThrows( - () -> { duckdb_jdbc_varchar_string_bytes(data, validity, rowCount, rowCount); }, SQLException.class); + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + assertNull(readable.getString(0)); + assertEquals(readable.getString(1), "duckdb"); + assertThrows(() -> { readable.getString(rowCount); }, IndexOutOfBoundsException.class); duckdb_destroy_vector(vec); duckdb_destroy_logical_type(lt); @@ -154,19 +219,46 @@ public static void test_bindings_vector_native_endian_roundtrip() throws Excepti int rowCount = (int) duckdb_vector_size(); int expected = 0x01020304; - DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); writable.setInt(0, expected); ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Integer.BYTES); assertEquals(rawData.order(ByteOrder.nativeOrder()).getInt(0), expected); - DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); assertEquals(readable.getInt(0), expected); duckdb_destroy_vector(vec); duckdb_destroy_logical_type(lt); } + public static void test_bindings_writable_vector_stack_trace_origin() 
throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + + try { + writable.setInt(0, 42); + fail("Expected setInt to reject VARCHAR vector"); + } catch (SQLException exception) { + assertTrue(exception.getMessage().contains("Expected vector type INTEGER, found VARCHAR")); + assertEquals(exception.getStackTrace()[0].getMethodName(), "setInt"); + } + + try { + writable.setString(rowCount, "boom"); + fail("Expected setString to reject out-of-bounds row"); + } catch (IndexOutOfBoundsException exception) { + assertTrue(exception.getMessage().contains("Row index out of bounds")); + assertEquals(exception.getStackTrace()[0].getMethodName(), "setString"); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_UBIGINT.typeId); ByteBuffer vec = duckdb_create_vector(lt); @@ -177,7 +269,7 @@ public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), new BigInteger("18446744073709551615")}; - DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); for (int i = 0; i < values.length; i++) { writable.setUint64(i, values[i]); } @@ -188,7 +280,7 @@ public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws assertEquals(nativeData.getLong(i * Long.BYTES), values[i].longValue()); } - DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); for (int i = 0; i < values.length; i++) { 
assertEquals(readable.getUint64(i), values[i]); } @@ -300,6 +392,53 @@ public static void test_bindings_validity() throws Exception { duckdb_destroy_logical_type(lt); } + public static void test_bindings_writable_vector_validity_word_boundaries() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + long rowCount = 70; + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + long[] boundaryRows = new long[] {63, 64, 65, rowCount - 1}; + long[] sentinelRows = new long[] {62, 66, 67, 68}; + + for (long row : boundaryRows) { + writable.setInt(row, (int) row); + } + for (long row : sentinelRows) { + writable.setInt(row, (int) (row + 10_000)); + } + + for (long row : boundaryRows) { + writable.setNull(row); + } + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + for (long row : boundaryRows) { + assertTrue(readable.isNull(row)); + } + for (long row : sentinelRows) { + assertFalse(readable.isNull(row)); + assertEquals(readable.getInt(row), (int) (row + 10_000)); + } + + for (long row : boundaryRows) { + writable.setInt(row, (int) (row + 1000)); + } + + DuckDBReadableVector revalidated = new DuckDBReadableVectorImpl(vec, rowCount); + for (long row : boundaryRows) { + assertFalse(revalidated.isNull(row)); + assertEquals(revalidated.getInt(row), (int) (row + 1000)); + } + for (long row : sentinelRows) { + assertFalse(revalidated.isNull(row)); + assertEquals(revalidated.getInt(row), (int) (row + 10_000)); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_data_chunk() throws Exception { ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer varcharType = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index 
54bd6255b..e730b1012 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -6,9 +6,13 @@ import static org.duckdb.TestDuckDBJDBC.JDBC_URL; import static org.duckdb.test.Assertions.*; +import java.lang.reflect.Proxy; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Connection; import java.sql.Date; import java.sql.DriverManager; import java.sql.ResultSet; @@ -20,12 +24,26 @@ import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.time.ZoneOffset; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; public class TestScalarFunctions { private interface ResultSetVerifier { void verify(ResultSet rs) throws Exception; } + private static int sumNonNullIntColumns(DuckDBScalarContext ctx, DuckDBScalarRow row) { + int sum = 0; + for (int columnIndex = 0; columnIndex < ctx.columnCount(); columnIndex++) { + if (!row.isNull(columnIndex)) { + sum += row.getInt(columnIndex); + } + } + return sum; + } + public static void test_bindings_scalar_function() throws Exception { ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer scalarFunction = duckdb_create_scalar_function(); @@ -49,26 +67,1084 @@ public static void test_bindings_scalar_function() throws Exception { assertThrows(() -> { duckdb_scalar_function_set_name(null, "x".getBytes(UTF_8)); }, SQLException.class); } - public static void test_register_scalar_function() throws Exception { - test_register_scalar_function_integer(); + public static void test_register_scalar_function() throws Exception { + test_register_scalar_function_integer(); + } + + public static void test_register_scalar_function_builder() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder") + .withParameter(intType) + .withReturnType(intType) + .propagateNulls(true) + .withVectorizedFunction( + ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); + assertEquals(function.name(), "java_add_int_builder"); + assertEquals(function.parameterTypes().size(), 1); + assertEquals(function.parameterTypes().get(0), intType); + assertEquals(function.returnType(), intType); + assertEquals(function.varArgType(), null); + assertEquals(function.isVolatile(), false); + assertEquals(function.hasSpecialHandling(), false); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_builder(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_connection_without_unwrap() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_connection") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_connection(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 
42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_returns_detached_metadata() throws Exception { + DuckDBRegisteredFunction function; + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + function = builder.withName("java_add_int_detached") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + String message = + assertThrows(() -> { builder.withName("java_add_int_detached_again"); }, SQLException.class); + assertTrue(message.contains("already finalized")); + + assertEquals(function.name(), "java_add_int_detached"); + assertEquals(function.parameterColumnTypes().size(), 1); + assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); + assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); + assertNotNull(function.function()); + assertEquals(function.functionKind(), DuckDBFunctionKind.SCALAR); + assertTrue(function.isScalar()); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_detached(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_registry_records_registered_functions() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_registry_recorded") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> 
x + 1) + .register(conn); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertEquals(registeredFunctions.size(), 1); + assertEquals(registeredFunctions.get(0), function); + assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctionKind.SCALAR); + assertTrue(registeredFunctions.get(0).isScalar()); + + try (ResultSet rs = stmt.executeQuery("SELECT java_registry_recorded(41)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_is_read_only() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_read_only") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertThrows(() -> { registeredFunctions.add(null); }, UnsupportedOperationException.class); + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_clear_only_clears_java_registry() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_clear_only") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + assertEquals(DuckDBDriver.registeredFunctions().size(), 1); + DuckDBDriver.clearFunctionsRegistry(); + assertEquals(DuckDBDriver.registeredFunctions().size(), 0); + + try (ResultSet rs = stmt.executeQuery("SELECT java_registry_clear_only(41)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + 
assertFalse(rs.next()); + } + } finally { + DuckDBDriver.clearFunctionsRegistry(); + } + } + + public static void test_register_scalar_function_registry_allows_duplicate_names() throws Exception { + DuckDBDriver.clearFunctionsRegistry(); + Path tempDir = Files.createTempDirectory("duckdb-registry"); + Path dbPathA = tempDir.resolve("registry-a.db"); + Path dbPathB = tempDir.resolve("registry-b.db"); + String urlA = "jdbc:duckdb:" + dbPathA.toAbsolutePath(); + String urlB = "jdbc:duckdb:" + dbPathB.toAbsolutePath(); + + try (Connection connA = DriverManager.getConnection(urlA); + Connection connB = DriverManager.getConnection(urlB)) { + DuckDBFunctions.scalarFunction() + .withName("java_registry_duplicate_name") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(connA); + DuckDBFunctions.scalarFunction() + .withName("java_registry_duplicate_name") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 2) + .register(connB); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); + assertEquals(registeredFunctions.size(), 2); + assertEquals(registeredFunctions.get(0).name(), "java_registry_duplicate_name"); + assertEquals(registeredFunctions.get(1).name(), "java_registry_duplicate_name"); + } finally { + DuckDBDriver.clearFunctionsRegistry(); + Files.deleteIfExists(dbPathA); + Files.deleteIfExists(dbPathB); + Files.deleteIfExists(tempDir); + } + } + + public static void test_register_scalar_function_builder_rejects_non_duckdb_connection() throws Exception { + Connection connection = (Connection) Proxy.newProxyInstance( + Connection.class.getClassLoader(), new Class[] {Connection.class}, (proxy, method, args) -> { + switch (method.getName()) { + case "unwrap": + throw new SQLException("not a DuckDB connection"); + case "isWrapperFor": + return false; + case "toString": + return "invalid-connection"; + case "hashCode": + return 
System.identityHashCode(proxy); + case "equals": + return proxy == args[0]; + default: + throw new UnsupportedOperationException(method.getName()); + } + }); + + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_connection") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1); + + String message = assertThrows(() -> { builder.register(connection); }, SQLException.class); + assertTrue(message.contains("requires a DuckDB JDBC connection")); + } + } + + public static void test_register_scalar_function_builder_varargs_and_flags() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs_builder") + .withParameter(intType) + .withVarArgs(intType) + .withReturnType(intType) + .propagateNulls(true) + .withVolatile() + .withSpecialHandling() + .withVectorizedFunction( + ctx -> { ctx.stream().forEachOrdered(row -> { row.setInt(sumNonNullIntColumns(ctx, row)); }); }) + .register(conn); + assertEquals(function.varArgType(), intType); + assertEquals(function.isVolatile(), true); + assertEquals(function.hasSpecialHandling(), true); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_sum_varargs_builder(1, 2, 3), java_sum_varargs_builder(5)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 6); + assertEquals(rs.getObject(2, Integer.class), 5); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_column_type_overloads() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = 
conn.createStatement()) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder_col_type") + .withParameter(DuckDBColumnType.INTEGER) + .withReturnType(DuckDBColumnType.INTEGER) + .propagateNulls(true) + .withVectorizedFunction( + ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); + assertEquals(function.parameterColumnTypes().size(), 1); + assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); + assertEquals(function.parameterTypes().get(0), null); + assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); + assertEquals(function.returnType(), null); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_builder_col_type(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_function(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function_propagate_nulls_false() throws Exception { + try (DuckDBConnection conn 
= DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_function_nullable") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x == null ? 99 : x + 1) + .propagateNulls(false) + .register(conn); + assertEquals(function.propagateNulls(), false); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_function_nullable(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 99); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_bifunction() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_bifunction") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x, Integer y) -> x + y) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_bifunction(a, b) FROM (VALUES (1, 2), (NULL, 2), (39, 3), (5, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_bifunction_propagate_nulls_false() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = + DuckDBFunctions.scalarFunction() + .withName("java_add_int_bifunction_nullable") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction( + (Integer left, Integer right) -> (left == null ? 0 : left) + (right == null ? 0 : right)) + .propagateNulls(false) + .register(conn); + assertEquals(function.propagateNulls(), false); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_bifunction_nullable(a, b) FROM (VALUES (1, 2), (NULL, 2), (39, NULL), (NULL, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 39); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 0); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_int_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_with_int_function(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void 
test_register_scalar_function_builder_with_int_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_binary_function") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_with_int_binary_function(a, b) " + + "FROM (VALUES (1, 2), (NULL, 2), (39, 3), (5, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_double_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_function") + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction(x -> x + 0.5d) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_double_with_double_function(v) FROM (VALUES (41.5), (NULL), (-2.5)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -2.0d); + assertFalse(rs.next()); + } + } + } + + public static void 
test_register_scalar_function_builder_with_double_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_binary_function") + .withParameter(Double.class) + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_double_with_double_binary_function(a, b) FROM " + + "(VALUES (1.0, 2.0), (NULL, 2.0), (39.5, 2.5), (5.0, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 3.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_with_long_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_function") + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction(x -> x + 3) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_long_with_long_function(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + } + } + } + + public static void 
test_register_scalar_function_builder_with_long_binary_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_binary_function") + .withParameter(Long.class) + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction((left, right) -> left + right) + .register(conn); + assertEquals(function.propagateNulls(), true); + + try (ResultSet rs = stmt.executeQuery("SELECT java_add_long_with_long_binary_function(a, b) " + + "FROM (VALUES (1::BIGINT, 2::BIGINT), (NULL, 2::BIGINT), " + + "(39::BIGINT, 3::BIGINT), (5::BIGINT, NULL)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 3L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_builder_java_function_class_cast_error() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_invalid_cast_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((String value) -> value.length()) + .register(conn); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_invalid_cast_function(1)"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("ClassCastException")); + } + } + + public static void test_register_scalar_function_builder_java_supplier() throws Exception { + assertNullaryJavaFunction("java_constant_supplier", Integer.class, + () -> 42, "SELECT 
java_constant_supplier() FROM range(3)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_supplier_null_value() throws Exception { + assertNullaryJavaFunction("java_null_supplier", String.class, + () -> null, "SELECT java_null_supplier() FROM range(2)", rs -> { + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_supplier_class_cast_error() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_invalid_supplier_cast") + .withReturnType(Integer.class) + .withFunction(() -> "not_an_integer") + .register(conn); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT java_invalid_supplier_cast()"); }, SQLException.class); + assertTrue(message.contains("Java scalar function threw exception")); + assertTrue(message.contains("ClassCastException")); + } + } + + public static void test_register_scalar_function_builder_java_supplier_rejects_parameters() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_supplier_arity").withParameter(Integer.class).withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withFunction(() -> 1); }, SQLException.class); + assertTrue(message.contains("Supplier callback requires zero declared parameters")); + } + } + + public static void test_register_scalar_function_builder_java_supplier_rejects_varargs() throws Exception { + try 
(DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_supplier_varargs").withReturnType(Integer.class).withVarArgs(intType); + String message = assertThrows(() -> { builder.withFunction(() -> 1); }, SQLException.class); + assertTrue(message.contains("Supplier callback does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_java_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_function_varargs") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = assertThrows(() -> { builder.withFunction((Integer x) -> x + 1); }, SQLException.class); + assertTrue(message.contains("Function callback does not support varargs")); + assertTrue(message.contains("withVarArgsFunction")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_int_function_varargs") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withIntFunction does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_propagate_nulls_false() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_int_function_null_propagation") + 
.withParameter(Integer.class) + .withReturnType(Integer.class) + .propagateNulls(false); + String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withIntFunction requires propagateNulls(true)")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_disabling_propagate_nulls() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_int_function_disable_null_propagation") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1); + String message = assertThrows(() -> { builder.propagateNulls(false); }, SQLException.class); + assertTrue(message.contains("Primitive scalar callbacks require propagateNulls(true)")); + } + } + + public static void test_register_scalar_function_builder_with_int_function_rejects_wrong_types() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_int_function_type") + .withParameter(Double.class) + .withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withIntFunction requires parameter 0 to be INTEGER")); + } + } + + public static void test_register_scalar_function_builder_java_bifunction_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + builder.withName("java_invalid_bifunction_varargs") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withVarArgs(intType); + String message = + assertThrows(() -> { builder.withFunction((Integer x, Integer y) -> x + y); }, SQLException.class); + assertTrue(message.contains("BiFunction callback does 
not support varargs")); + assertTrue(message.contains("withVarArgsFunction")); + } + } + + public static void test_register_scalar_function_builder_with_double_function_rejects_propagate_nulls_false() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_double_function_null_propagation") + .withParameter(Double.class) + .withReturnType(Double.class) + .propagateNulls(false); + String message = assertThrows(() -> { builder.withDoubleFunction(x -> x + 0.5d); }, SQLException.class); + assertTrue(message.contains("withDoubleFunction requires propagateNulls(true)")); + } + } + + public static void test_register_scalar_function_builder_with_double_function_rejects_wrong_types() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_double_function_type") + .withParameter(Integer.class) + .withReturnType(Double.class); + String message = assertThrows(() -> { builder.withDoubleFunction(x -> x + 0.5d); }, SQLException.class); + assertTrue(message.contains("withDoubleFunction requires parameter 0 to be DOUBLE")); + } + } + + public static void test_register_scalar_function_builder_with_long_function_rejects_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction(); + DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { + builder.withName("java_invalid_long_function_varargs") + .withParameter(Long.class) + .withReturnType(Long.class) + .withVarArgs(bigintType); + String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withLongFunction does not support varargs")); + } + } + + public static void test_register_scalar_function_builder_with_long_function_rejects_propagate_nulls_false() + throws Exception { + try (DuckDBScalarFunctionBuilder builder = 
DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_long_function_null_propagation") + .withParameter(Long.class) + .withReturnType(Long.class) + .propagateNulls(false); + String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withLongFunction requires propagateNulls(true)")); + } + } + + public static void test_register_scalar_function_builder_with_long_function_rejects_wrong_types() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_long_function_type").withParameter(Integer.class).withReturnType(Long.class); + String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); + assertTrue(message.contains("withLongFunction requires parameter 0 to be BIGINT")); + } + } + + public static void test_register_scalar_function_builder_java_varargs_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_sum_varargs_function") + .withParameter(Integer.class) + .withVarArgs(intType) + .withReturnType(Integer.class) + .withVarArgsFunction(args -> { + int sum = 0; + for (Object arg : args) { + sum += (Integer) arg; + } + return sum; + }) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_sum_varargs_function(1, 2, 3), java_sum_varargs_function(5), " + + "java_sum_varargs_function(NULL, 2), java_sum_varargs_function(2, NULL)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 6); + assertEquals(rs.getObject(2, Integer.class), 5); + assertEquals(rs.getObject(3), null); + assertEquals(rs.getObject(4), null); + assertFalse(rs.next()); + } + } + } + + public static void 
test_register_scalar_function_builder_java_varargs_function_requires_varargs() throws Exception { + try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { + builder.withName("java_invalid_varargs_function") + .withParameter(Integer.class) + .withReturnType(Integer.class); + String message = assertThrows(() -> { builder.withVarArgsFunction(args -> 0); }, SQLException.class); + assertTrue(message.contains("requires withVarArgs")); + } + } + + public static void test_register_scalar_function_builder_java_function_supported_class_types() throws Exception { + Function<Boolean, Boolean> notBoolean = value -> !value; + Function<Byte, Byte> addTinyInt = value -> (byte) (value + 1); + Function<Long, Long> addBigInt = value -> value + 3; + Function<Double, Double> addDouble = value -> value + 0.5; + Function<String, String> suffixString = value -> value + "_ok"; + Function<LocalDate, LocalDate> addDate = value -> value.plusDays(2); + Function<LocalDateTime, LocalDateTime> addTimestamp = value -> value.plusMinutes(30); + Function<OffsetDateTime, OffsetDateTime> addTimestampTz = value -> value.plusMinutes(5); + assertUnaryJavaFunction("java_not_bool_function", Boolean.class, Boolean.class, notBoolean, + "SELECT java_not_bool_function(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_tinyint_function", Byte.class, Byte.class, addTinyInt, + "SELECT java_add_tinyint_function(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_bigint_function", Long.class, Long.class, addBigInt, + "SELECT java_add_bigint_function(v) FROM (VALUES
(39::BIGINT), (NULL), (-5::BIGINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_double_function", Double.class, Double.class, addDouble, + "SELECT java_add_double_function(v) FROM (VALUES (41.5::DOUBLE), (NULL), (-2.5::DOUBLE)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -2.0); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction("java_suffix_varchar_function", String.class, String.class, suffixString, + "SELECT java_suffix_varchar_function(v) FROM (VALUES ('duck'), (NULL), ('db')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "db_ok"); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_date_function", LocalDate.class, LocalDate.class, addDate, + "SELECT java_add_date_function(v) FROM (VALUES (DATE '2024-07-21'), (NULL), (DATE '2024-07-30')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 7, 23)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDate.class), LocalDate.of(2024, 8, 1)); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_timestamp_function", LocalDateTime.class, LocalDateTime.class, addTimestamp, + "SELECT java_add_timestamp_function(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + 
assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(1970, 1, 1, 0, 15, 0)); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_timestamptz_function", OffsetDateTime.class, OffsetDateTime.class, addTimestampTz, + "SELECT java_add_timestamptz_function(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue(rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_supported_unsigned_types() throws Exception { + Function<Short, Short> addUTinyInt = value -> (short) (value + 1); + Function<Integer, Integer> addUSmallInt = value -> value + 2; + Function<Long, Long> addUInteger = value -> value + 3; + Function<BigInteger, BigInteger> addUBigInt = value -> value.add(BigInteger.ONE); + assertUnaryJavaFunction( + "java_add_utinyint_function", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, addUTinyInt, + "SELECT java_add_utinyint_function(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Short.class), (short) 255); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_usmallint_function", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, addUSmallInt, + "SELECT java_add_usmallint_function(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); +
assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 65535); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_uinteger_function", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, addUInteger, + "SELECT java_add_uinteger_function(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 4294967295L); + assertFalse(rs.next()); + }); + + assertUnaryJavaFunction( + "java_add_ubigint_function", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, addUBigInt, + "SELECT java_add_ubigint_function(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_decimal() throws Exception { + try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { + Function<BigDecimal, BigDecimal> addDecimal = value -> value.add(new BigDecimal("1.25")); + assertUnaryJavaFunction("java_add_decimal_function", decimalType, decimalType, addDecimal, + "SELECT java_add_decimal_function(v) FROM (VALUES (CAST(40.75 AS DECIMAL(10,2))), " + + "(NULL), (CAST(-1.25 AS DECIMAL(10,2)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("42.00")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("0.00")); + assertFalse(rs.next()); + }); + } + } + + public static void
test_register_scalar_function_builder_java_function_sql_date_class_mapping() throws Exception { + Function<java.sql.Date, java.sql.Date> addSqlDate = + value -> java.sql.Date.valueOf(value.toLocalDate().plusDays(1)); + assertUnaryJavaFunction( + "java_add_sql_date_function", java.sql.Date.class, java.sql.Date.class, addSqlDate, + "SELECT java_add_sql_date_function(v) FROM (VALUES (DATE '2024-07-21'), (NULL), (DATE '2024-07-30')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Date.class), java.sql.Date.valueOf("2024-07-22")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Date.class), java.sql.Date.valueOf("2024-07-31")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_sql_timestamp_class_mapping() + throws Exception { + Function<java.sql.Timestamp, java.sql.Timestamp> addSqlTimestamp = + value -> java.sql.Timestamp.valueOf(value.toLocalDateTime().plusSeconds(1)); + assertUnaryJavaFunction("java_add_sql_timestamp_function", java.sql.Timestamp.class, java.sql.Timestamp.class, + addSqlTimestamp, + "SELECT java_add_sql_timestamp_function(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, java.sql.Timestamp.class), + java.sql.Timestamp.valueOf("2024-07-21 12:34:57.123456")); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_java_util_date_class_mapping() + throws Exception { + Function<java.util.Date, java.util.Date> addUtilDate = value -> new java.util.Date(value.getTime() + 1000L); + assertUnaryJavaFunction("java_add_java_util_date_function", java.util.Date.class, java.util.Date.class, + addUtilDate, + "SELECT java_add_java_util_date_function(v) = date_trunc('millisecond', v) + " + + "INTERVAL 1 SECOND FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), (NULL)) t(v)", + rs -> { +
assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_bifunction_supported_types() throws Exception { + BiFunction<String, String, String> concatUnderscore = (left, right) -> left + "_" + right; + BiFunction<Double, Double, Double> sumDouble = Double::sum; + assertBinaryJavaFunction( + "java_concat_varchar_bifunction", String.class, String.class, String.class, concatUnderscore, + "SELECT java_concat_varchar_bifunction(a, b) FROM (VALUES ('duck', 'db'), (NULL, 'x'), ('a', NULL)) t(a, b)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_db"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); + + assertBinaryJavaFunction( + "java_add_double_bifunction", Double.class, Double.class, Double.class, sumDouble, + "SELECT java_add_double_bifunction(a, b) FROM (VALUES (10.5::DOUBLE, 31.5::DOUBLE), (NULL, 2::DOUBLE), (2::DOUBLE, NULL)) t(a, b)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_typed_logical_type() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - conn.registerScalarFunction("java_add_int_typed", new DuckDBLogicalType[] {intType}, intType, - (input, out) -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setInt(i, in.getInt(i) + 1); - } - } - }); + DuckDBFunctions.scalarFunction() +
.withName("java_add_int_typed") + .withParameter(intType) + .withReturnType(intType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_typed(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { assertTrue(rs.next()); @@ -87,14 +1163,20 @@ public static void test_register_scalar_function_parallel() throws Exception { Statement stmt = conn.createStatement(); DuckDBLogicalType bigintType = DuckDBLogicalType.of(DuckDBColumnType.BIGINT)) { stmt.execute("PRAGMA threads=4"); - conn.registerScalarFunction("java_add_one_bigint", new DuckDBLogicalType[] {bigintType}, bigintType, - (input, out) -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - out.setLong(i, in.getLong(i) + 1); - } - }); + DuckDBFunctions.scalarFunction() + .withName("java_add_one_bigint") + .withParameter(bigintType) + .withReturnType(bigintType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + DuckDBReadableVector in = ctx.input(0); + long rowCount = ctx.rowCount(); + for (long row = 0; row < rowCount; row++) { + out.setLong(row, in.getLong(row) + 1); + } + }) + .register(conn); try (ResultSet rs = stmt.executeQuery("SELECT sum(java_add_one_bigint(i)) FROM range(1000000) t(i)")) { assertTrue(rs.next()); @@ -105,12 +1187,138 @@ public static void test_register_scalar_function_parallel() throws Exception { } } + public static void test_register_scalar_function_context_row_stream_int() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_row_stream") + .withParameter(intType) + 
.withReturnType(intType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_row_stream(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_context_row_stream_double() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement(); + DuckDBLogicalType doubleType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { + DuckDBFunctions.scalarFunction() + .withName("java_add_double_row_stream") + .withParameter(doubleType) + .withReturnType(doubleType) + .propagateNulls(true) + .withVectorizedFunction( + ctx -> { ctx.stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_double_row_stream(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_context_row_stream_propagate_nulls_false() throws Exception { + assertUnaryScalarFunction( + "java_suffix_varchar_row_stream_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, false, + ctx + -> { + ctx.stream().forEachOrdered(row -> { + String value = row.getString(0); + row.setString(value == null ? 
"NULL_SEEN" : value + "_ok"); + }); + }, + "SELECT java_suffix_varchar_row_stream_nullable(v) FROM (VALUES ('duck'), (NULL), ('db')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_ok"); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "NULL_SEEN"); + assertFalse(rs.wasNull()); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "db_ok"); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_integer_append() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_append") + .withParameter(DuckDBColumnType.INTEGER) + .withReturnType(DuckDBColumnType.INTEGER) + .withIntFunction(x -> x + 1) + .register(conn); + + try (ResultSet rs = + stmt.executeQuery("SELECT java_add_int_append(v) FROM (VALUES (41), (NULL), (-2)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), -1); + assertFalse(rs.next()); + } + } + } + + public static void test_register_scalar_function_bigint_append() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_bigint_append") + .withParameter(DuckDBColumnType.BIGINT) + .withReturnType(DuckDBColumnType.BIGINT) + .withLongFunction(x -> x + 1) + .register(conn); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_bigint_append(v) FROM (VALUES (41::BIGINT), (NULL), (-2::BIGINT)) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + 
assertEquals(rs.getObject(1, Long.class), -1L); + assertFalse(rs.next()); + } + } + } + public static void test_register_scalar_function_exception_propagation() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - conn.registerScalarFunction("java_throws_exception", new DuckDBLogicalType[] {intType}, intType, - (input, out) -> { throw new IllegalStateException("boom"); }); + DuckDBFunctions.scalarFunction() + .withName("java_throws_exception") + .withParameter(intType) + .withReturnType(intType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { throw new IllegalStateException("boom"); }) + .register(conn); String message = assertThrows(() -> { stmt.executeQuery("SELECT java_throws_exception(1)"); }, SQLException.class); assertTrue(message.contains("Java scalar function threw exception")); @@ -121,18 +1329,8 @@ public static void test_register_scalar_function_exception_propagation() throws public static void test_register_scalar_function_boolean() throws Exception { assertUnaryScalarFunction("java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setBoolean(i, !in.getBoolean(i)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setBoolean(!row.getBoolean(0))); }, "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", rs -> { assertTrue(rs.next()); @@ -147,18 +1345,8 @@ public static void test_register_scalar_function_boolean() throws Exception { public static void test_register_scalar_function_tinyint() throws Exception { assertUnaryScalarFunction("java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, - (input, out) - -> 
{ - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setByte(i, (byte) (in.getByte(i) + 1)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setByte((byte) (row.getByte(0) + 1))); }, "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -174,18 +1362,8 @@ public static void test_register_scalar_function_tinyint() throws Exception { public static void test_register_scalar_function_smallint() throws Exception { assertUnaryScalarFunction( "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setShort(i, (short) (in.getShort(i) + 2)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setShort((short) (row.getShort(0) + 2))); }, "SELECT java_add_smallint(v) FROM (VALUES (40::SMALLINT), (NULL), (-4::SMALLINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -200,18 +1378,8 @@ public static void test_register_scalar_function_smallint() throws Exception { public static void test_register_scalar_function_integer() throws Exception { assertUnaryScalarFunction("java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setInt(i, in.getInt(i) + 1); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }, "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", rs -> { assertTrue(rs.next()); @@ -226,18 +1394,12 @@ public static void test_register_scalar_function_integer() throws Exception { public static 
void test_register_scalar_function_integer_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setNull(i); - out.setInt(i, in.getInt(i) + 1); - } - } + ctx + -> { + ctx.stream().forEachOrdered(row -> { + row.setNull(); + row.setInt(row.getInt(0) + 1); + }); }, "SELECT java_revalidate_int(v) FROM (VALUES (41), (NULL)) t(v)", rs -> { @@ -252,18 +1414,8 @@ public static void test_register_scalar_function_integer_revalidates_after_null( public static void test_register_scalar_function_bigint() throws Exception { assertUnaryScalarFunction("java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setLong(i, in.getLong(i) + 3); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setLong(row.getLong(0) + 3)); }, "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -279,18 +1431,8 @@ public static void test_register_scalar_function_bigint() throws Exception { public static void test_register_scalar_function_utinyint() throws Exception { assertUnaryScalarFunction( "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setUint8(i, in.getUint8(i) + 1); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setUint8(row.getUint8(0) + 1)); }, "SELECT java_add_utinyint(v) FROM (VALUES 
(41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -306,18 +1448,8 @@ public static void test_register_scalar_function_utinyint() throws Exception { public static void test_register_scalar_function_usmallint() throws Exception { assertUnaryScalarFunction( "java_add_usmallint", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setUint16(i, in.getUint16(i) + 2); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setUint16(row.getUint16(0) + 2)); }, "SELECT java_add_usmallint(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -333,18 +1465,8 @@ public static void test_register_scalar_function_usmallint() throws Exception { public static void test_register_scalar_function_uinteger() throws Exception { assertUnaryScalarFunction( "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setUint32(i, in.getUint32(i) + 3); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setUint32(row.getUint32(0) + 3)); }, "SELECT java_add_uinteger(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", rs -> { assertTrue(rs.next()); @@ -358,48 +1480,30 @@ public static void test_register_scalar_function_uinteger() throws Exception { } public static void test_register_scalar_function_ubigint() throws Exception { - assertUnaryScalarFunction("java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - BigInteger increment = BigInteger.ONE; - 
for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setUint64(i, in.getUint64(i).add(increment)); - } - } - }, - "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " - + "(18446744073709551614::UBIGINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), - new BigInteger("18446744073709551615")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, + ctx + -> { + BigInteger increment = BigInteger.ONE; + ctx.stream().forEachOrdered(row -> row.setUint64(row.getUint64(0).add(increment))); + }, + "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_float() throws Exception { assertUnaryScalarFunction("java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setFloat(i, in.getFloat(i) + 1.25f); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setFloat(row.getFloat(0) + 1.25f)); }, "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", rs -> { assertTrue(rs.next()); @@ -414,18 +1518,8 @@ public static void test_register_scalar_function_float() throws Exception { public static void 
test_register_scalar_function_double() throws Exception { assertUnaryScalarFunction("java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setDouble(i, in.getDouble(i) + 1.5d); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }, "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", rs -> { assertTrue(rs.next()); @@ -440,34 +1534,26 @@ public static void test_register_scalar_function_double() throws Exception { public static void test_register_scalar_function_decimal() throws Exception { try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { - assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - BigDecimal increment = new BigDecimal("0.0000000001"); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setBigDecimal(i, in.getBigDecimal(i).add(increment)); - } - } - }, - "SELECT java_add_decimal(v) FROM (VALUES " - + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " - + "(NULL), " - + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), - new BigDecimal("12345678901234567890.1234567891")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_decimal", decimalType, decimalType, + ctx + -> { + BigDecimal increment = new BigDecimal("0.0000000001"); + ctx.stream().forEachOrdered(row -> 
row.setBigDecimal(row.getBigDecimal(0).add(increment))); + }, + "SELECT java_add_decimal(v) FROM (VALUES " + + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " + + "(NULL), " + + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("12345678901234567890.1234567891")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); + assertFalse(rs.next()); + }); } } @@ -475,13 +1561,19 @@ public static void test_register_scalar_function_decimal_precision_overflow() th try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { - conn.registerScalarFunction("java_decimal_precision_overflow", new DuckDBLogicalType[] {decimalType}, - decimalType, (input, out) -> { - int rowCount = input.rowCount(); - for (int i = 0; i < rowCount; i++) { - out.setBigDecimal(i, new BigDecimal("12345678901.23")); - } - }); + DuckDBFunctions.scalarFunction() + .withName("java_decimal_precision_overflow") + .withParameter(decimalType) + .withReturnType(decimalType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("12345678901.23")); + } + }) + .register(conn); String err = assertThrows(() -> { stmt.execute("SELECT java_decimal_precision_overflow(CAST(1 AS DECIMAL(10,2)))"); @@ -494,13 +1586,19 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { - 
conn.registerScalarFunction("java_decimal_scale_overflow", new DuckDBLogicalType[] {decimalType}, - decimalType, (input, out) -> { - int rowCount = input.rowCount(); - for (int i = 0; i < rowCount; i++) { - out.setBigDecimal(i, new BigDecimal("1.234")); - } - }); + DuckDBFunctions.scalarFunction() + .withName("java_decimal_scale_overflow") + .withParameter(decimalType) + .withReturnType(decimalType) + .propagateNulls(true) + .withVectorizedFunction(ctx -> { + DuckDBWritableVector out = ctx.output(); + long rowCount = ctx.rowCount(); + for (int i = 0; i < rowCount; i++) { + out.setBigDecimal(i, new BigDecimal("1.234")); + } + }) + .register(conn); String err = assertThrows(() -> { stmt.execute("SELECT java_decimal_scale_overflow(CAST(1 AS DECIMAL(10,2)))"); @@ -512,18 +1610,8 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws public static void test_register_scalar_function_date() throws Exception { assertUnaryScalarFunction( "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setDate(i, in.getLocalDate(i).plusDays(2)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setDate(row.getLocalDate(0).plusDays(2))); }, "SELECT java_add_date(v) FROM (VALUES (DATE '2024-07-20'), (NULL), (DATE '1969-12-31')) t(v)", rs -> { assertTrue(rs.next()); @@ -539,18 +1627,12 @@ public static void test_register_scalar_function_date() throws Exception { public static void test_register_scalar_function_date_from_java_util_date() throws Exception { assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { 
- LocalDate value = in.getLocalDate(i).plusDays(1); - out.setDate(i, java.util.Date.from(value.atStartOfDay(UTC).toInstant())); - } - } + ctx + -> { + ctx.stream().forEachOrdered(row -> { + LocalDate value = row.getLocalDate(0).plusDays(1); + row.setDate(java.util.Date.from(value.atStartOfDay(UTC).toInstant())); + }); }, "SELECT java_date_from_util_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", rs -> { @@ -563,52 +1645,32 @@ public static void test_register_scalar_function_date_from_java_util_date() thro } public static void test_register_scalar_function_timestamp() throws Exception { - assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i).plusMinutes(30)); - } - } - }, - "SELECT java_add_timestamp(v) FROM (VALUES " - + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " - + "(NULL), " - + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", - rs -> { - assertTrue(rs.next()); - Timestamp ts1 = rs.getTimestamp(1); - assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusMinutes(30))); }, + "SELECT java_add_timestamp(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + Timestamp ts1 = 
rs.getTimestamp(1); + assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); + assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_s() throws Exception { assertUnaryScalarFunction( "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i).plusSeconds(2)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusSeconds(2))); }, "SELECT java_add_timestamp_s(v) FROM (VALUES (TIMESTAMP_S '2024-07-21 12:34:56'), (NULL)) t(v)", rs -> { assertTrue(rs.next()); @@ -622,18 +1684,8 @@ public static void test_register_scalar_function_timestamp_s() throws Exception public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_S, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, "SELECT java_copy_timestamp_s_pre_epoch(v) FROM (VALUES " + "(TIMESTAMP '1969-12-31 23:59:59.999')) t(v)", rs -> { @@ -644,46 +1696,27 @@ public static void test_register_scalar_function_timestamp_s_pre_epoch() throws } public static void 
test_register_scalar_function_timestamp_ms() throws Exception { - assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i).plusNanos(7_000_000)); - } - } - }, - "SELECT java_add_timestamp_ms(v) FROM (VALUES " - + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); - assertTrue(rs.next()); - assertNullRow(rs); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(7_000_000))); }, + "SELECT java_add_timestamp_ms(v) FROM (VALUES " + + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_MS, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, "SELECT java_copy_timestamp_ms_pre_epoch(v) FROM (VALUES " + "(TIMESTAMP 
'1969-12-31 23:59:59.9995')) t(v)", rs -> { @@ -695,47 +1728,28 @@ public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws } public static void test_register_scalar_function_timestamp_ns() throws Exception { - assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDateTime(i).plusNanos(789)); - } - } - }, - "SELECT java_add_timestamp_ns(v) FROM (VALUES " - + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); - assertTrue(rs.next()); - assertNullRow(rs); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(789))); }, + "SELECT java_add_timestamp_ns(v) FROM (VALUES " + + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamptz() throws Exception { assertUnaryScalarFunction( "java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setOffsetDateTime(i, 
in.getOffsetDateTime(i).plusMinutes(5)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setOffsetDateTime(row.getOffsetDateTime(0).plusMinutes(5))); }, "SELECT java_add_timestamptz(v) FROM (VALUES " + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", rs -> { @@ -752,18 +1766,8 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr assertUnaryScalarFunction( "java_copy_timestamptz_with_timestamp", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getTimestamp(i)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getTimestamp(0))); }, "SELECT java_copy_timestamptz_with_timestamp(v) FROM (VALUES " + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00')) t(v)", rs -> { @@ -777,18 +1781,11 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_date", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); + ctx + -> { long oneSecondMillis = 1000L; - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, new java.util.Date(in.getTimestamp(i).getTime() + oneSecondMillis)); - } - } + ctx.stream().forEachOrdered( + row -> { row.setTimestamp(new java.util.Date(row.getTimestamp(0).getTime() + oneSecondMillis)); }); }, "SELECT epoch_ms(java_timestamp_from_util_date(v)) FROM (VALUES " + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " @@ -807,18 +1804,12 @@ public static void 
test_register_scalar_function_timestamp_from_java_util_date() public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - java.util.Date value = Timestamp.valueOf(in.getLocalDateTime(i).plusNanos(789000)); - out.setTimestamp(i, value); - } - } + ctx + -> { + ctx.stream().forEachOrdered(row -> { + java.util.Date value = Timestamp.valueOf(row.getLocalDateTime(0).plusNanos(789000)); + row.setTimestamp(value); + }); }, "SELECT java_timestamp_from_util_ts(v) FROM (VALUES (TIMESTAMP '2024-07-21 12:34:56.123456')) t(v)", rs -> { @@ -832,18 +1823,12 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_sql_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - java.util.Date value = Date.valueOf(in.getLocalDate(i)); - out.setTimestamp(i, value); - } - } + ctx + -> { + ctx.stream().forEachOrdered(row -> { + java.util.Date value = Date.valueOf(row.getLocalDate(0)); + row.setTimestamp(value); + }); }, "SELECT epoch_ms(java_timestamp_from_util_sql_date(v)) FROM (VALUES (DATE '2024-07-21')) t(v)", rs -> { @@ -856,13 +1841,12 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws Exception { 
assertUnaryScalarFunction( "java_timestamp_from_util_sql_time", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - for (int i = 0; i < rowCount; i++) { + ctx + -> { + ctx.stream().forEachOrdered(row -> { java.util.Date value = Time.valueOf("12:34:56"); - out.setTimestamp(i, value); - } + row.setTimestamp(value); + }); }, "SELECT epoch_ms(java_timestamp_from_util_sql_time(v)) FROM (VALUES (TIMESTAMP '2024-07-21 00:00:00')) t(v)", rs -> { @@ -875,18 +1859,8 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setTimestamp(i, in.getLocalDate(i).plusDays(1)); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDate(0).plusDays(1))); }, "SELECT java_timestamp_from_local_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", rs -> { assertTrue(rs.next()); @@ -900,18 +1874,8 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr public static void test_register_scalar_function_varchar() throws Exception { assertUnaryScalarFunction("java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setString(i, in.getString(i) + "_java"); - } - } - }, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setString(row.getString(0) + "_java")); }, "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " + 
"('abcdefghijklmnop')) t(v)", rs -> { @@ -927,14 +1891,9 @@ public static void test_register_scalar_function_varchar() throws Exception { public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { assertUnaryScalarFunction("java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - out.setString(i, in.getString(i)); - } - }, + false, + ctx + -> { ctx.stream().forEachOrdered(row -> row.setString(row.getString(0))); }, "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " + "('abcdefghijklmnop')) t(v)", rs -> { @@ -950,18 +1909,12 @@ public static void test_register_scalar_function_varchar_get_string_handles_null public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - (input, out) - -> { - int rowCount = input.rowCount(); - DuckDBReadableVector in = input.vector(0); - for (int i = 0; i < rowCount; i++) { - if (in.isNull(i)) { - out.setNull(i); - } else { - out.setNull(i); - out.setString(i, in.getString(i) + "_ok"); - } - } + ctx + -> { + ctx.stream().forEachOrdered(row -> { + row.setNull(); + row.setString(row.getString(0) + "_ok"); + }); }, "SELECT java_revalidate_varchar(v) FROM (VALUES ('duck'), (NULL)) t(v)", rs -> { @@ -977,18 +1930,123 @@ public static void test_register_scalar_function_varchar_revalidates_after_null( private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, DuckDBColumnType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { + assertUnaryScalarFunction(functionName, parameterType, returnType, true, function, query, verifier); + } + + private static void 
assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, boolean propagateNulls, + DuckDBScalarFunction function, String query, + ResultSetVerifier verifier) throws Exception { try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { - assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); + assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, propagateNulls, function, + query, verifier); } } private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, DuckDBLogicalType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { + assertUnaryScalarFunction(functionName, parameterType, returnType, true, function, query, verifier); + } + + private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, boolean propagateNulls, + DuckDBScalarFunction function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .propagateNulls(propagateNulls) + .withVectorizedFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, Class parameterType, Class returnType, + Function function, String query, ResultSetVerifier verifier) + throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + 
DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertBinaryJavaFunction(String functionName, Class leftType, Class rightType, + Class returnType, BiFunction function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(leftType) + .withParameter(rightType) + .withReturnType(returnType) + 
.withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullaryJavaFunction(String functionName, Class returnType, Supplier function, + String query, ResultSetVerifier verifier) throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - conn.registerScalarFunction(functionName, new DuckDBLogicalType[] {parameterType}, returnType, function); + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withReturnType(returnType) + .withFunction(function) + .register(conn); try (ResultSet rs = stmt.executeQuery(query)) { verifier.verify(rs); } From cfa37ba0bd061c5783500735aa3a02f0518e86b7 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Mon, 6 Apr 2026 21:14:07 -0300 Subject: [PATCH 6/9] Use long indices in scalar chunk readers --- .../java/org/duckdb/DuckDBDataChunkReader.java | 15 ++++++++------- src/main/java/org/duckdb/DuckDBScalarContext.java | 8 ++++---- .../org/duckdb/DuckDBScalarFunctionAdapter.java | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java index a628ce1ff..6e7161c85 100644 --- a/src/main/java/org/duckdb/DuckDBDataChunkReader.java +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -8,7 +8,7 @@ public final class DuckDBDataChunkReader { private final ByteBuffer chunkRef; private final long rowCount; - private final int columnCount; + private final long columnCount; private final DuckDBReadableVector[] vectors; DuckDBDataChunkReader(ByteBuffer chunkRef) throws SQLException { @@ -17,27 +17,28 @@ public final class DuckDBDataChunkReader { } this.chunkRef = chunkRef; this.rowCount = duckdb_data_chunk_get_size(chunkRef); - this.columnCount = Math.toIntExact(duckdb_data_chunk_get_column_count(chunkRef)); - 
this.vectors = new DuckDBReadableVector[columnCount]; + this.columnCount = duckdb_data_chunk_get_column_count(chunkRef); + this.vectors = new DuckDBReadableVector[Math.toIntExact(columnCount)]; } public long rowCount() { return rowCount; } - public int columnCount() { + public long columnCount() { return columnCount; } - public DuckDBReadableVector vector(int columnIndex) throws SQLException { + public DuckDBReadableVector vector(long columnIndex) throws SQLException { if (columnIndex < 0 || columnIndex >= columnCount) { throw new IndexOutOfBoundsException("Column index out of bounds: " + columnIndex); } - DuckDBReadableVector vector = vectors[columnIndex]; + int arrayIndex = Math.toIntExact(columnIndex); + DuckDBReadableVector vector = vectors[arrayIndex]; if (vector == null) { ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); vector = new DuckDBReadableVectorImpl(vectorRef, rowCount); - vectors[columnIndex] = vector; + vectors[arrayIndex] = vector; } return vector; } diff --git a/src/main/java/org/duckdb/DuckDBScalarContext.java b/src/main/java/org/duckdb/DuckDBScalarContext.java index 79689175b..c9d4fc8be 100644 --- a/src/main/java/org/duckdb/DuckDBScalarContext.java +++ b/src/main/java/org/duckdb/DuckDBScalarContext.java @@ -25,11 +25,11 @@ public long rowCount() { return input.rowCount(); } - public int columnCount() { + public long columnCount() { return input.columnCount(); } - public DuckDBReadableVector input(int columnIndex) throws SQLException { + public DuckDBReadableVector input(long columnIndex) throws SQLException { return input.vector(columnIndex); } @@ -58,7 +58,7 @@ public DuckDBScalarRow row(long rowIndex) { return new DuckDBScalarRow(this, rowIndex); } - DuckDBReadableVector inputUnchecked(int columnIndex) { + DuckDBReadableVector inputUnchecked(long columnIndex) { try { return input(columnIndex); } catch (SQLException exception) { @@ -73,7 +73,7 @@ private void checkRowIndex(long rowIndex) { } private boolean 
rowHasNoNullInputs(long rowIndex) { - for (int columnIndex = 0; columnIndex < columnCount(); columnIndex++) { + for (long columnIndex = 0; columnIndex < columnCount(); columnIndex++) { if (inputUnchecked(columnIndex).isNull(rowIndex)) { try { output.setNull(rowIndex); diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java index dbefe0b50..11456fa3b 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -298,7 +298,7 @@ static DuckDBScalarFunction variadic(Function function, DuckDBColum return ctx -> { DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); - int vectorCount = ctx.columnCount(); + int vectorCount = Math.toIntExact(ctx.columnCount()); boolean propagateNulls = ctx.propagateNulls(); DuckDBReadableVector[] vectors = new DuckDBReadableVector[vectorCount]; TypeCodec[] codecs = new TypeCodec[vectorCount]; From 590aa3dd42982126daed74aa4474347006caf9eb Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Mon, 6 Apr 2026 22:12:35 -0300 Subject: [PATCH 7/9] Sync Windows export list with current JNI symbols --- duckdb_java.def | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/duckdb_java.def b/duckdb_java.def index dc9e5a4ae..7f340311b 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -57,9 +57,11 @@ Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function 
Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error -Java_org_duckdb_DuckDBBindings_duckdb_1jdbc_1varchar_1string_1bytes Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id @@ -83,6 +85,8 @@ Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1data Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1validity Java_org_duckdb_DuckDBBindings_duckdb_1vector_1ensure_1validity_1writable Java_org_duckdb_DuckDBBindings_duckdb_1vector_1assign_1string_1element_1len +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string +Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J Java_org_duckdb_DuckDBBindings_duckdb_1validity_1row_1is_1valid Java_org_duckdb_DuckDBBindings_duckdb_1validity_1set_1row_1validity Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1get_1child @@ -316,7 +320,6 @@ duckdb_destroy_selection_vector duckdb_destroy_table_function duckdb_destroy_task_state duckdb_destroy_value -duckdb_destroy_vector duckdb_disconnect duckdb_double_to_decimal duckdb_double_to_hugeint From 68a7f91c1142637977a41e7b8fce10f15bb48601 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Tue, 7 Apr 2026 22:57:18 -0300 Subject: [PATCH 8/9] Refine scalar UDF callback API, runtime model, and 128-bit support Squash all local scalar-function work into a single cohesive change set guided by review feedback. This adopts context-driven null propagation, moves null-aware primitive handling into vector readers, unifies callback-time failures under DuckDBFunctionException (while documenting bounds violations as IndexOutOfBoundsException), and removes checked SQLExceptions from callback runtime APIs. 
It also adds complete HUGEINT/UHUGEINT callback read/write support, keeps BigInteger class mapping on HUGEINT, hoists vector native byte-order setup, syncs scalar JNI sources in CMake templates, and adds JNI ExceptionCheck guards. Tests and docs are updated to cover the new behavior across scalar callbacks, bindings, null handling, and unsigned 128-bit round-trips. --- CMakeLists.txt.in | 2 + UDF.MD | 25 +- src/jni/bindings_scalar_function.cpp | 6 + .../org/duckdb/DuckDBDataChunkReader.java | 12 +- .../org/duckdb/DuckDBFunctionException.java | 13 + .../java/org/duckdb/DuckDBFunctionKind.java | 3 - src/main/java/org/duckdb/DuckDBFunctions.java | 2 + src/main/java/org/duckdb/DuckDBHugeInt.java | 22 + .../java/org/duckdb/DuckDBLogicalType.java | 41 +- .../java/org/duckdb/DuckDBReadableVector.java | 77 +- .../org/duckdb/DuckDBReadableVectorImpl.java | 274 +++++-- .../org/duckdb/DuckDBRegisteredFunction.java | 10 +- .../java/org/duckdb/DuckDBScalarContext.java | 30 +- .../duckdb/DuckDBScalarFunctionAdapter.java | 212 +++-- .../duckdb/DuckDBScalarFunctionBuilder.java | 42 +- src/main/java/org/duckdb/DuckDBScalarRow.java | 202 +++-- .../java/org/duckdb/DuckDBVectorTypeInfo.java | 4 + .../java/org/duckdb/DuckDBWritableVector.java | 109 +-- .../org/duckdb/DuckDBWritableVectorImpl.java | 288 ++++--- src/test/java/org/duckdb/TestBindings.java | 42 +- .../java/org/duckdb/TestScalarFunctions.java | 732 +++++++++++------- 21 files changed, 1394 insertions(+), 754 deletions(-) create mode 100644 src/main/java/org/duckdb/DuckDBFunctionException.java delete mode 100644 src/main/java/org/duckdb/DuckDBFunctionKind.java diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index 3d6f041d6..751076c19 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -109,12 +109,14 @@ add_library(duckdb_java SHARED src/jni/bindings_common.cpp src/jni/bindings_data_chunk.cpp src/jni/bindings_logical_type.cpp + src/jni/bindings_scalar_function.cpp src/jni/bindings_validity.cpp 
src/jni/bindings_vector.cpp src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp src/jni/refs.cpp + src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/UDF.MD b/UDF.MD index 0398165b7..f8407c0e2 100644 --- a/UDF.MD +++ b/UDF.MD @@ -55,13 +55,18 @@ SELECT java_weighted_sum(2.5, 4.0); Behavior: -- `propagateNulls(true)` is the default. -- With `propagateNulls(true)`, NULL input propagates to NULL output for `Function`/`BiFunction`. -- With `propagateNulls(false)`, functional callbacks receive `null` arguments and can decide the output. -- `withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)` require `propagateNulls(true)`. +- `Function` and `BiFunction` callbacks receive `null` for NULL inputs; implement null handling in Java callback logic. +- `withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)` run with null propagation enabled. +- For `withVectorizedFunction(...)`, use `ctx.propagateNulls(true)` when you want stream-level null skipping and automatic NULL output. - For `Supplier`, returning `null` writes NULL output. - `Function` and `BiFunction` are fixed arity only (no varargs). +Runtime error model: + +- Callback-time reader/writer/context type and value failures throw `DuckDBFunctionException`. +- Invalid row/column indexes throw `IndexOutOfBoundsException`. +- `SQLException` remains for registration-time API usage and type declaration/validation. + ## Type declaration and mapping `withParameter(...)` and `withReturnType(...)` accept: @@ -76,9 +81,15 @@ Common class mappings include: - `Long` -> `BIGINT` - `String` -> `VARCHAR` - `BigDecimal` -> `DECIMAL` +- `BigInteger` -> `HUGEINT` - `LocalDate` and `java.sql.Date` -> `DATE` - `LocalDateTime`, `java.sql.Timestamp`, and `java.util.Date` -> `TIMESTAMP` +Notes: + +- `UHUGEINT` is supported through explicit `DuckDBColumnType.UHUGEINT`/`DuckDBLogicalType` declarations. 
+- Java class auto-mapping for `BigInteger` remains `HUGEINT`. + Use `DuckDBLogicalType.decimal(width, scale)` for explicit DECIMAL precision/scale. ## Varargs @@ -119,6 +130,7 @@ Notes: - `withName(String)` - `withParameter(Class | DuckDBColumnType | DuckDBLogicalType)` +- `withParameters(Class...)` - `withReturnType(Class | DuckDBColumnType | DuckDBLogicalType)` - `withFunction(Supplier | Function | BiFunction)` - `withIntFunction(IntUnaryOperator | IntBinaryOperator)` @@ -127,7 +139,6 @@ Notes: - `withVarArgs(DuckDBLogicalType)` - `withVarArgsFunction(Function)` - `withVectorizedFunction(DuckDBScalarFunction)` -- `propagateNulls(boolean)` - `withVolatile()` - `withSpecialHandling()` - `register(java.sql.Connection)` @@ -167,9 +178,8 @@ try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); .withParameter(strType) .withParameter(dblType) .withReturnType(strType) - .propagateNulls(true) .withVectorizedFunction(ctx -> { - ctx.stream().forEachOrdered(row -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { String value = row.getLocalDateTime(0) + " | " + row.getString(1).trim().toUpperCase() + " | " + row.getDouble(2); @@ -187,5 +197,6 @@ SELECT java_event_label(TIMESTAMP '2026-04-04 12:00:00', 'launch', 4.5); Lifecycle rules: - `DuckDBScalarContext`, `DuckDBScalarRow`, `DuckDBReadableVector`, and `DuckDBWritableVector` are valid only during callback execution. +- `DuckDBReadableVector` and `DuckDBWritableVector` are abstract callback runtime types (not interfaces). - Write exactly one output value per input row for each callback invocation. - With `propagateNulls(true)`, `DuckDBScalarContext.stream()` skips rows that contain NULL in any input column and writes NULL to the output for those rows. 
diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp index 54e6557a9..5fb9c0b7a 100644 --- a/src/jni/bindings_scalar_function.cpp +++ b/src/jni/bindings_scalar_function.cpp @@ -61,6 +61,9 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 return; } auto function_name = jbyteArray_to_string(env, name); + if (env->ExceptionCheck()) { + return; + } duckdb_scalar_function_set_name(function, function_name.c_str()); } @@ -159,5 +162,8 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 return; } auto error_message = jbyteArray_to_string(env, error); + if (env->ExceptionCheck()) { + return; + } duckdb_scalar_function_set_error(function_info, error_message.c_str()); } diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java index 6e7161c85..23fc16190 100644 --- a/src/main/java/org/duckdb/DuckDBDataChunkReader.java +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -3,17 +3,21 @@ import static org.duckdb.DuckDBBindings.*; import java.nio.ByteBuffer; -import java.sql.SQLException; +/** + * Reader over callback input data chunks. + * + *

Column index violations throw {@link IndexOutOfBoundsException}. + */ public final class DuckDBDataChunkReader { private final ByteBuffer chunkRef; private final long rowCount; private final long columnCount; private final DuckDBReadableVector[] vectors; - DuckDBDataChunkReader(ByteBuffer chunkRef) throws SQLException { + DuckDBDataChunkReader(ByteBuffer chunkRef) { if (chunkRef == null) { - throw new SQLException("Invalid data chunk reference"); + throw new DuckDBFunctionException("Invalid data chunk reference"); } this.chunkRef = chunkRef; this.rowCount = duckdb_data_chunk_get_size(chunkRef); @@ -29,7 +33,7 @@ public long columnCount() { return columnCount; } - public DuckDBReadableVector vector(long columnIndex) throws SQLException { + public DuckDBReadableVector vector(long columnIndex) { if (columnIndex < 0 || columnIndex >= columnCount) { throw new IndexOutOfBoundsException("Column index out of bounds: " + columnIndex); } diff --git a/src/main/java/org/duckdb/DuckDBFunctionException.java b/src/main/java/org/duckdb/DuckDBFunctionException.java new file mode 100644 index 000000000..2010fe433 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBFunctionException.java @@ -0,0 +1,13 @@ +package org.duckdb; + +public final class DuckDBFunctionException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public DuckDBFunctionException(String message) { + super(message); + } + + public DuckDBFunctionException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/org/duckdb/DuckDBFunctionKind.java b/src/main/java/org/duckdb/DuckDBFunctionKind.java deleted file mode 100644 index 1ec4c8f46..000000000 --- a/src/main/java/org/duckdb/DuckDBFunctionKind.java +++ /dev/null @@ -1,3 +0,0 @@ -package org.duckdb; - -public enum DuckDBFunctionKind { SCALAR } diff --git a/src/main/java/org/duckdb/DuckDBFunctions.java b/src/main/java/org/duckdb/DuckDBFunctions.java index a8e1dad3e..0a9337301 100644 --- 
a/src/main/java/org/duckdb/DuckDBFunctions.java +++ b/src/main/java/org/duckdb/DuckDBFunctions.java @@ -3,6 +3,8 @@ import java.sql.SQLException; public final class DuckDBFunctions { + public enum DuckDBFunctionKind { SCALAR } + private DuckDBFunctions() { } diff --git a/src/main/java/org/duckdb/DuckDBHugeInt.java b/src/main/java/org/duckdb/DuckDBHugeInt.java index 5912e0fa2..eeca81f83 100644 --- a/src/main/java/org/duckdb/DuckDBHugeInt.java +++ b/src/main/java/org/duckdb/DuckDBHugeInt.java @@ -1,11 +1,13 @@ package org.duckdb; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.sql.SQLException; class DuckDBHugeInt { static final BigInteger HUGE_INT_MIN = BigInteger.ONE.shiftLeft(127).negate(); static final BigInteger HUGE_INT_MAX = BigInteger.ONE.shiftLeft(127).subtract(BigInteger.ONE); + static final BigInteger UHUGE_INT_MAX = BigInteger.ONE.shiftLeft(128).subtract(BigInteger.ONE); private final long lower; private final long upper; @@ -25,4 +27,24 @@ class DuckDBHugeInt { this.lower = bi.longValue(); this.upper = bi.shiftRight(64).longValue(); } + + static BigInteger toBigInteger(long lower, long upper) { + byte[] bytes = new byte[Long.BYTES * 2]; + ByteBuffer.wrap(bytes).putLong(upper).putLong(lower); + return new BigInteger(bytes); + } + + static BigInteger toUnsignedBigInteger(long lower, long upper) { + byte[] bytes = new byte[Long.BYTES * 2]; + ByteBuffer.wrap(bytes).putLong(upper).putLong(lower); + return new BigInteger(1, bytes); + } + + long lower() { + return lower; + } + + long upper() { + return upper; + } } diff --git a/src/main/java/org/duckdb/DuckDBLogicalType.java b/src/main/java/org/duckdb/DuckDBLogicalType.java index 1578e6ba2..26a0f7436 100644 --- a/src/main/java/org/duckdb/DuckDBLogicalType.java +++ b/src/main/java/org/duckdb/DuckDBLogicalType.java @@ -1,5 +1,6 @@ package org.duckdb; +import static org.duckdb.DuckDBBindings.CAPIType.*; import static org.duckdb.DuckDBBindings.*; import java.nio.ByteBuffer; @@ -21,41 +22,45 @@ 
public static DuckDBLogicalType of(DuckDBColumnType type) throws SQLException { } switch (type) { case BOOLEAN: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_BOOLEAN); + return createPrimitive(DUCKDB_TYPE_BOOLEAN); case TINYINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TINYINT); + return createPrimitive(DUCKDB_TYPE_TINYINT); case SMALLINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_SMALLINT); + return createPrimitive(DUCKDB_TYPE_SMALLINT); case INTEGER: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER); + return createPrimitive(DUCKDB_TYPE_INTEGER); case BIGINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_BIGINT); + return createPrimitive(DUCKDB_TYPE_BIGINT); + case HUGEINT: + return createPrimitive(DUCKDB_TYPE_HUGEINT); case UTINYINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UTINYINT); + return createPrimitive(DUCKDB_TYPE_UTINYINT); case USMALLINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_USMALLINT); + return createPrimitive(DUCKDB_TYPE_USMALLINT); case UINTEGER: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UINTEGER); + return createPrimitive(DUCKDB_TYPE_UINTEGER); case UBIGINT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_UBIGINT); + return createPrimitive(DUCKDB_TYPE_UBIGINT); + case UHUGEINT: + return createPrimitive(DUCKDB_TYPE_UHUGEINT); case FLOAT: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_FLOAT); + return createPrimitive(DUCKDB_TYPE_FLOAT); case DOUBLE: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_DOUBLE); + return createPrimitive(DUCKDB_TYPE_DOUBLE); case VARCHAR: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_VARCHAR); + return createPrimitive(DUCKDB_TYPE_VARCHAR); case DATE: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_DATE); + return createPrimitive(DUCKDB_TYPE_DATE); case TIMESTAMP_S: - return 
createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_S); + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_S); case TIMESTAMP_MS: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_MS); + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_MS); case TIMESTAMP: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP); + return createPrimitive(DUCKDB_TYPE_TIMESTAMP); case TIMESTAMP_NS: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_NS); + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_NS); case TIMESTAMP_WITH_TIME_ZONE: - return createPrimitive(DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_TZ); + return createPrimitive(DUCKDB_TYPE_TIMESTAMP_TZ); default: throw new SQLException("Unsupported logical type for scalar UDF registration: " + type); } diff --git a/src/main/java/org/duckdb/DuckDBReadableVector.java b/src/main/java/org/duckdb/DuckDBReadableVector.java index b997f53a1..52324ad70 100644 --- a/src/main/java/org/duckdb/DuckDBReadableVector.java +++ b/src/main/java/org/duckdb/DuckDBReadableVector.java @@ -3,55 +3,84 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.sql.Date; -import java.sql.SQLException; import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.util.stream.LongStream; -public interface DuckDBReadableVector { - DuckDBColumnType getType(); +/** + * Read-only scalar callback view over a DuckDB vector. + * + *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value errors. + * Invalid row indexes throw {@link IndexOutOfBoundsException}. + */ +public abstract class DuckDBReadableVector { + public abstract DuckDBColumnType getType(); - long rowCount(); + public abstract long rowCount(); - LongStream rowIndexStream(); + public abstract LongStream rowIndexStream(); - boolean isNull(long row); + public abstract boolean isNull(long row); - boolean getBoolean(long row) throws SQLException; + public abstract boolean getBoolean(long row); - byte getByte(long row) throws SQLException; + public abstract boolean getBoolean(long row, boolean defaultVal); - short getShort(long row) throws SQLException; + public abstract byte getByte(long row); - short getUint8(long row) throws SQLException; + public abstract byte getByte(long row, byte defaultVal); - int getUint16(long row) throws SQLException; + public abstract short getShort(long row); - int getInt(long row) throws SQLException; + public abstract short getShort(long row, short defaultVal); - long getUint32(long row) throws SQLException; + public abstract short getUint8(long row); - long getLong(long row) throws SQLException; + public abstract short getUint8(long row, short defaultVal); - BigInteger getUint64(long row) throws SQLException; + public abstract int getUint16(long row); - float getFloat(long row) throws SQLException; + public abstract int getUint16(long row, int defaultVal); - double getDouble(long row) throws SQLException; + public abstract int getInt(long row); - LocalDate getLocalDate(long row) throws SQLException; + public abstract int getInt(long row, int defaultVal); - Date getDate(long row) throws SQLException; + public abstract long getUint32(long row); - LocalDateTime getLocalDateTime(long row) throws SQLException; + public abstract long getUint32(long row, long defaultVal); - Timestamp getTimestamp(long row) throws SQLException; + public abstract long getLong(long row); - 
OffsetDateTime getOffsetDateTime(long row) throws SQLException; + public abstract long getLong(long row, long defaultVal); - BigDecimal getBigDecimal(long row) throws SQLException; + public abstract BigInteger getHugeInt(long row); - String getString(long row) throws SQLException; + public abstract BigInteger getUHugeInt(long row); + + public abstract BigInteger getUint64(long row); + + public abstract float getFloat(long row); + + public abstract float getFloat(long row, float defaultVal); + + public abstract double getDouble(long row); + + public abstract double getDouble(long row, double defaultVal); + + public abstract LocalDate getLocalDate(long row); + + public abstract Date getDate(long row); + + public abstract LocalDateTime getLocalDateTime(long row); + + public abstract Timestamp getTimestamp(long row); + + public abstract OffsetDateTime getOffsetDateTime(long row); + + public abstract BigDecimal getBigDecimal(long row); + + public abstract String getString(long row); } diff --git a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java index 1f3d27184..fe4c322e3 100644 --- a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java +++ b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java @@ -8,7 +8,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.sql.Date; -import java.sql.SQLException; import java.sql.Timestamp; import java.time.Instant; import java.time.LocalDate; @@ -18,7 +17,7 @@ import java.time.temporal.ChronoUnit; import java.util.stream.LongStream; -final class DuckDBReadableVectorImpl implements DuckDBReadableVector { +final class DuckDBReadableVectorImpl extends DuckDBReadableVector { private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); @@ -28,15 +27,20 @@ final class DuckDBReadableVectorImpl implements DuckDBReadableVector { private final ByteBuffer data; 
private final ByteBuffer validity; - DuckDBReadableVectorImpl(ByteBuffer vectorRef, long rowCount) throws SQLException { + DuckDBReadableVectorImpl(ByteBuffer vectorRef, long rowCount) { if (vectorRef == null) { - throw new SQLException("Invalid vector reference"); + throw new DuckDBFunctionException("Invalid vector reference"); } this.vectorRef = vectorRef; this.rowCount = rowCount; - this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); - this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)); - this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Failed to resolve vector type info", exception); + } + this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); + this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); } @Override @@ -61,151 +65,285 @@ public boolean isNull(long row) { return false; } int entryPos = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); - long mask = validity.order(NATIVE_ORDER).getLong(entryPos); + long mask = validity.getLong(entryPos); return (mask & (1L << (row % Long.SIZE))) == 0; } @Override - public boolean getBoolean(long row) throws SQLException { + public boolean getBoolean(long row) { requireType(DuckDBColumnType.BOOLEAN); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BOOLEAN, row); + } return data.get(checkedRowIndex(row)) != 0; } @Override - public byte getByte(long row) throws SQLException { + public boolean getBoolean(long row, boolean defaultVal) { + requireType(DuckDBColumnType.BOOLEAN); + return isNull(row) ? 
defaultVal : data.get(checkedRowIndex(row)) != 0; + } + + @Override + public byte getByte(long row) { requireType(DuckDBColumnType.TINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.TINYINT, row); + } return data.get(checkedRowIndex(row)); } @Override - public short getShort(long row) throws SQLException { + public byte getByte(long row, byte defaultVal) { + requireType(DuckDBColumnType.TINYINT); + return isNull(row) ? defaultVal : data.get(checkedRowIndex(row)); + } + + @Override + public short getShort(long row) { requireType(DuckDBColumnType.SMALLINT); - return data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES)); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.SMALLINT, row); + } + return data.getShort(checkedByteOffset(row, Short.BYTES)); } @Override - public short getUint8(long row) throws SQLException { + public short getShort(long row, short defaultVal) { + requireType(DuckDBColumnType.SMALLINT); + return isNull(row) ? defaultVal : data.getShort(checkedByteOffset(row, Short.BYTES)); + } + + @Override + public short getUint8(long row) { requireType(DuckDBColumnType.UTINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UTINYINT, row); + } return (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); } @Override - public int getUint16(long row) throws SQLException { + public short getUint8(long row, short defaultVal) { + requireType(DuckDBColumnType.UTINYINT); + return isNull(row) ? 
defaultVal : (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + @Override + public int getUint16(long row) { + requireType(DuckDBColumnType.USMALLINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.USMALLINT, row); + } + return Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); + } + + @Override + public int getUint16(long row, int defaultVal) { requireType(DuckDBColumnType.USMALLINT); - return Short.toUnsignedInt(data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES))); + return isNull(row) ? defaultVal : Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); } @Override - public int getInt(long row) throws SQLException { + public int getInt(long row) { requireType(DuckDBColumnType.INTEGER); - return data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES)); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.INTEGER, row); + } + return data.getInt(checkedByteOffset(row, Integer.BYTES)); } @Override - public long getUint32(long row) throws SQLException { + public int getInt(long row, int defaultVal) { + requireType(DuckDBColumnType.INTEGER); + return isNull(row) ? defaultVal : data.getInt(checkedByteOffset(row, Integer.BYTES)); + } + + @Override + public long getUint32(long row) { requireType(DuckDBColumnType.UINTEGER); - return Integer.toUnsignedLong(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES))); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UINTEGER, row); + } + return Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); } @Override - public long getLong(long row) throws SQLException { + public long getUint32(long row, long defaultVal) { + requireType(DuckDBColumnType.UINTEGER); + return isNull(row) ? 
defaultVal : Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + @Override + public long getLong(long row) { requireType(DuckDBColumnType.BIGINT); - return data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BIGINT, row); + } + return data.getLong(checkedByteOffset(row, Long.BYTES)); } @Override - public BigInteger getUint64(long row) throws SQLException { + public long getLong(long row, long defaultVal) { + requireType(DuckDBColumnType.BIGINT); + return isNull(row) ? defaultVal : data.getLong(checkedByteOffset(row, Long.BYTES)); + } + + @Override + public BigInteger getHugeInt(long row) { + requireType(DuckDBColumnType.HUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toBigInteger(lower, upper); + } + + @Override + public BigInteger getUHugeInt(long row) { + requireType(DuckDBColumnType.UHUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toUnsignedBigInteger(lower, upper); + } + + @Override + public BigInteger getUint64(long row) { requireType(DuckDBColumnType.UBIGINT); - long value = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + if (isNull(row)) { + return null; + } + long value = data.getLong(checkedByteOffset(row, Long.BYTES)); return unsignedLongToBigInteger(value); } @Override - public float getFloat(long row) throws SQLException { + public float getFloat(long row) { requireType(DuckDBColumnType.FLOAT); - return data.order(NATIVE_ORDER).getFloat(checkedByteOffset(row, Float.BYTES)); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.FLOAT, row); + } + return 
data.getFloat(checkedByteOffset(row, Float.BYTES)); } @Override - public double getDouble(long row) throws SQLException { + public float getFloat(long row, float defaultVal) { + requireType(DuckDBColumnType.FLOAT); + return isNull(row) ? defaultVal : data.getFloat(checkedByteOffset(row, Float.BYTES)); + } + + @Override + public double getDouble(long row) { requireType(DuckDBColumnType.DOUBLE); - return data.order(NATIVE_ORDER).getDouble(checkedByteOffset(row, Double.BYTES)); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.DOUBLE, row); + } + return data.getDouble(checkedByteOffset(row, Double.BYTES)); } @Override - public LocalDate getLocalDate(long row) throws SQLException { + public double getDouble(long row, double defaultVal) { + requireType(DuckDBColumnType.DOUBLE); + return isNull(row) ? defaultVal : data.getDouble(checkedByteOffset(row, Double.BYTES)); + } + + @Override + public LocalDate getLocalDate(long row) { requireType(DuckDBColumnType.DATE); - return LocalDate.ofEpochDay(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES))); + if (isNull(row)) { + return null; + } + return LocalDate.ofEpochDay(data.getInt(checkedByteOffset(row, Integer.BYTES))); } @Override - public Date getDate(long row) throws SQLException { - return Date.valueOf(getLocalDate(row)); + public Date getDate(long row) { + LocalDate value = getLocalDate(row); + return value == null ? 
null : Date.valueOf(value); } @Override - public LocalDateTime getLocalDateTime(long row) throws SQLException { + public LocalDateTime getLocalDateTime(long row) { requireTimestampType(); - long epochValue = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); - switch (typeInfo.capiType) { - case DUCKDB_TYPE_TIMESTAMP_S: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); - case DUCKDB_TYPE_TIMESTAMP_MS: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); - case DUCKDB_TYPE_TIMESTAMP: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); - case DUCKDB_TYPE_TIMESTAMP_NS: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); - case DUCKDB_TYPE_TIMESTAMP_TZ: - return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); - default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + if (isNull(row)) { + return null; + } + long epochValue = data.getLong(checkedByteOffset(row, Long.BYTES)); + try { + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); + case DUCKDB_TYPE_TIMESTAMP: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); + default: + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } catch (java.sql.SQLException exception) { + throw new 
DuckDBFunctionException("Failed to decode timestamp at row " + row, exception); } } @Override - public Timestamp getTimestamp(long row) throws SQLException { - return Timestamp.valueOf(getLocalDateTime(row)); + public Timestamp getTimestamp(long row) { + LocalDateTime value = getLocalDateTime(row); + return value == null ? null : Timestamp.valueOf(value); } @Override - public OffsetDateTime getOffsetDateTime(long row) throws SQLException { + public OffsetDateTime getOffsetDateTime(long row) { requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); - long micros = data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)); + if (isNull(row)) { + return null; + } + long micros = data.getLong(checkedByteOffset(row, Long.BYTES)); Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); } @Override - public BigDecimal getBigDecimal(long row) throws SQLException { + public BigDecimal getBigDecimal(long row) { requireType(DuckDBColumnType.DECIMAL); + if (isNull(row)) { + return null; + } switch (typeInfo.storageType) { case DUCKDB_TYPE_SMALLINT: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getShort(checkedByteOffset(row, Short.BYTES)), + return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_INTEGER: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getInt(checkedByteOffset(row, Integer.BYTES)), + return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_BIGINT: - return BigDecimal.valueOf(data.order(NATIVE_ORDER).getLong(checkedByteOffset(row, Long.BYTES)), + return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_HUGEINT: { - ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); - slice.position(checkedByteOffset(row, typeInfo.widthBytes)); - long lower = slice.getLong(); - long 
upper = slice.getLong(); + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); return new BigDecimal(upper) .multiply(ULONG_MULTIPLIER) .add(new BigDecimal(Long.toUnsignedString(lower))) .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); } default: - throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); } } @Override - public String getString(long row) throws SQLException { + public String getString(long row) { requireType(DuckDBColumnType.VARCHAR); if (isNull(row)) { return null; @@ -221,13 +359,13 @@ ByteBuffer vectorRef() { return vectorRef; } - private void requireType(DuckDBColumnType expected) throws SQLException { + private void requireType(DuckDBColumnType expected) { if (typeInfo.columnType != expected) { - throw new SQLException("Expected vector type " + expected + ", found " + typeInfo.columnType); + throw new DuckDBFunctionException("Expected vector type " + expected + ", found " + typeInfo.columnType); } } - private void requireTimestampType() throws SQLException { + private void requireTimestampType() { switch (typeInfo.columnType) { case TIMESTAMP: case TIMESTAMP_S: @@ -236,7 +374,7 @@ private void requireTimestampType() throws SQLException { case TIMESTAMP_WITH_TIME_ZONE: return; default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); } } @@ -256,7 +394,7 @@ private int checkedByteOffset(long row, int elementWidth) { return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); } - private static Instant instantFromEpoch(long value, ChronoUnit unit) throws SQLException { + private static Instant instantFromEpoch(long value, ChronoUnit unit) { switch (unit) { case SECONDS: return 
Instant.ofEpochSecond(value); @@ -273,7 +411,7 @@ private static Instant instantFromEpoch(long value, ChronoUnit unit) throws SQLE return Instant.ofEpochSecond(epochSecond, nanoAdjustment); } default: - throw new SQLException("Unsupported unit type: " + unit); + throw new DuckDBFunctionException("Unsupported unit type: " + unit); } } @@ -283,4 +421,8 @@ private static BigInteger unsignedLongToBigInteger(long value) { } return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); } + + private static DuckDBFunctionException primitiveNullValue(DuckDBColumnType type, long row) { + return new DuckDBFunctionException("Primitive value for " + type + " at row " + row + " is NULL"); + } } diff --git a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java index f824a00b9..517b7908c 100644 --- a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java +++ b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java @@ -6,7 +6,7 @@ public final class DuckDBRegisteredFunction { private final String name; - private final DuckDBFunctionKind functionKind; + private final DuckDBFunctions.DuckDBFunctionKind functionKind; private final List parameterTypes; private final List parameterColumnTypes; private final DuckDBLogicalType returnType; @@ -17,7 +17,7 @@ public final class DuckDBRegisteredFunction { private final boolean specialHandlingFlag; private final boolean propagateNullsFlag; - private DuckDBRegisteredFunction(String name, DuckDBFunctionKind functionKind, + private DuckDBRegisteredFunction(String name, DuckDBFunctions.DuckDBFunctionKind functionKind, List parameterTypes, List parameterColumnTypes, DuckDBLogicalType returnType, DuckDBColumnType returnColumnType, DuckDBScalarFunction function, @@ -40,7 +40,7 @@ public String name() { return name; } - public DuckDBFunctionKind functionKind() { + public DuckDBFunctions.DuckDBFunctionKind functionKind() { return functionKind; } @@ -81,7 +81,7 @@ public boolean 
propagateNulls() { } public boolean isScalar() { - return functionKind == DuckDBFunctionKind.SCALAR; + return functionKind == DuckDBFunctions.DuckDBFunctionKind.SCALAR; } static DuckDBRegisteredFunction of(String name, List parameterTypes, @@ -90,7 +90,7 @@ static DuckDBRegisteredFunction of(String name, List paramete DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, boolean propagateNullsFlag) { return new DuckDBRegisteredFunction( - name, DuckDBFunctionKind.SCALAR, Collections.unmodifiableList(new ArrayList<>(parameterTypes)), + name, DuckDBFunctions.DuckDBFunctionKind.SCALAR, Collections.unmodifiableList(new ArrayList<>(parameterTypes)), Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), returnType, returnColumnType, function, varArgType, volatileFlag, specialHandlingFlag, propagateNullsFlag); } diff --git a/src/main/java/org/duckdb/DuckDBScalarContext.java b/src/main/java/org/duckdb/DuckDBScalarContext.java index c9d4fc8be..ad0c927ae 100644 --- a/src/main/java/org/duckdb/DuckDBScalarContext.java +++ b/src/main/java/org/duckdb/DuckDBScalarContext.java @@ -1,13 +1,18 @@ package org.duckdb; -import java.sql.SQLException; import java.util.stream.LongStream; import java.util.stream.Stream; +/** + * Per-invocation scalar callback context. + * + *

Runtime type/value failures are surfaced as {@link DuckDBFunctionException}. Invalid row or + * column indexes throw {@link IndexOutOfBoundsException}. + */ public final class DuckDBScalarContext { private final DuckDBDataChunkReader input; private final DuckDBWritableVector output; - private final boolean propagateNulls; + private boolean propagateNulls; DuckDBScalarContext(DuckDBDataChunkReader input, DuckDBWritableVector output, boolean propagateNulls) { if (input == null) { @@ -29,7 +34,7 @@ public long columnCount() { return input.columnCount(); } - public DuckDBReadableVector input(long columnIndex) throws SQLException { + public DuckDBReadableVector input(long columnIndex) { return input.vector(columnIndex); } @@ -37,10 +42,15 @@ public DuckDBWritableVector output() { return output; } - public boolean propagateNulls() { + public boolean nullsPropagated() { return propagateNulls; } + public DuckDBScalarContext propagateNulls(boolean propagateNulls) { + this.propagateNulls = propagateNulls; + return this; + } + DuckDBDataChunkReader inputChunk() { return input; } @@ -59,11 +69,7 @@ public DuckDBScalarRow row(long rowIndex) { } DuckDBReadableVector inputUnchecked(long columnIndex) { - try { - return input(columnIndex); - } catch (SQLException exception) { - throw new IllegalStateException("Failed to access input column " + columnIndex, exception); - } + return input(columnIndex); } private void checkRowIndex(long rowIndex) { @@ -75,11 +81,7 @@ private void checkRowIndex(long rowIndex) { private boolean rowHasNoNullInputs(long rowIndex) { for (long columnIndex = 0; columnIndex < columnCount(); columnIndex++) { if (inputUnchecked(columnIndex).isNull(rowIndex)) { - try { - output.setNull(rowIndex); - } catch (SQLException exception) { - throw new IllegalStateException("Failed to write NULL to output row " + rowIndex, exception); - } + output.setNull(rowIndex); return false; } } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java 
b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java index 11456fa3b..f95f9599f 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -53,6 +53,10 @@ final class DuckDBScalarFunctionAdapter { register(DuckDBColumnType.UINTEGER, Long.class, DuckDBReadableVector::getUint32, (out, row, value) -> out.setUint32(row, value)); register(DuckDBColumnType.BIGINT, Long.class, DuckDBReadableVector::getLong, DuckDBWritableVector::setLong); + register(DuckDBColumnType.HUGEINT, BigInteger.class, DuckDBReadableVector::getHugeInt, + DuckDBWritableVector::setHugeInt); + register(DuckDBColumnType.UHUGEINT, BigInteger.class, DuckDBReadableVector::getUHugeInt, + DuckDBWritableVector::setUHugeInt); register(DuckDBColumnType.UBIGINT, BigInteger.class, DuckDBReadableVector::getUint64, DuckDBWritableVector::setUint64); register(DuckDBColumnType.FLOAT, Float.class, DuckDBReadableVector::getFloat, DuckDBWritableVector::setFloat); @@ -91,7 +95,7 @@ final class DuckDBScalarFunctionAdapter { DUCKDB_TYPE_BY_JAVA_CLASS.put(Double.class, DuckDBColumnType.DOUBLE); DUCKDB_TYPE_BY_JAVA_CLASS.put(String.class, DuckDBColumnType.VARCHAR); DUCKDB_TYPE_BY_JAVA_CLASS.put(BigDecimal.class, DuckDBColumnType.DECIMAL); - DUCKDB_TYPE_BY_JAVA_CLASS.put(BigInteger.class, DuckDBColumnType.UBIGINT); + DUCKDB_TYPE_BY_JAVA_CLASS.put(BigInteger.class, DuckDBColumnType.HUGEINT); DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDate.class, DuckDBColumnType.DATE); DUCKDB_TYPE_BY_JAVA_CLASS.put(java.sql.Date.class, DuckDBColumnType.DATE); DUCKDB_TYPE_BY_JAVA_CLASS.put(LocalDateTime.class, DuckDBColumnType.TIMESTAMP); @@ -118,15 +122,20 @@ static DuckDBScalarFunction unary(Function function, DuckDBColumnType para DuckDBReadableVector in = ctx.input(0); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); - boolean propagateNulls = ctx.propagateNulls(); + boolean propagateNulls = ctx.nullsPropagated(); for (long row = 0; row < 
rowCount; row++) { - if (propagateNulls && in.isNull(row)) { - out.setNull(row); - continue; + try { + if (propagateNulls && in.isNull(row)) { + out.setNull(row); + continue; + } + Object argument = in.isNull(row) ? null : inCodec.read(in, row); + Object result = typedFunction.apply(argument); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute unary scalar function at row " + row, + exception); } - Object argument = in.isNull(row) ? null : inCodec.read(in, row); - Object result = typedFunction.apply(argument); - outCodec.write(out, row, result); } }; } @@ -144,35 +153,45 @@ static DuckDBScalarFunction binary(BiFunction function, DuckDBColumnTyp DuckDBReadableVector right = ctx.input(1); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); - boolean propagateNulls = ctx.propagateNulls(); + boolean propagateNulls = ctx.nullsPropagated(); for (long row = 0; row < rowCount; row++) { - boolean leftIsNull = left.isNull(row); - boolean rightIsNull = right.isNull(row); - if (propagateNulls && (leftIsNull || rightIsNull)) { - out.setNull(row); - continue; + try { + boolean leftIsNull = left.isNull(row); + boolean rightIsNull = right.isNull(row); + if (propagateNulls && (leftIsNull || rightIsNull)) { + out.setNull(row); + continue; + } + Object leftValue = leftIsNull ? null : leftCodec.read(left, row); + Object rightValue = rightIsNull ? null : rightCodec.read(right, row); + Object result = typedFunction.apply(leftValue, rightValue); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute binary scalar function at row " + row, + exception); } - Object leftValue = leftIsNull ? null : leftCodec.read(left, row); - Object rightValue = rightIsNull ? 
null : rightCodec.read(right, row); - Object result = typedFunction.apply(leftValue, rightValue); - outCodec.write(out, row, result); } }; } static DuckDBScalarFunction intUnary(IntUnaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withIntFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); } DuckDBReadableVector in = ctx.input(0); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (in.isNull(row)) { - out.setNull(row); - } else { - out.setInt(row, function.applyAsInt(in.getInt(row))); + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(in.getInt(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, + exception); } } }; @@ -180,18 +199,23 @@ static DuckDBScalarFunction intUnary(IntUnaryOperator function) { static DuckDBScalarFunction intBinary(IntBinaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withIntFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); } DuckDBReadableVector left = ctx.input(0); DuckDBReadableVector right = ctx.input(1); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); - } else { - out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); + try { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); + } + } catch (DuckDBFunctionException exception) { + throw new 
DuckDBFunctionException("Failed to execute withIntFunction at row " + row, + exception); } } }; @@ -199,17 +223,22 @@ static DuckDBScalarFunction intBinary(IntBinaryOperator function) { static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withDoubleFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); } DuckDBReadableVector in = ctx.input(0); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (in.isNull(row)) { - out.setNull(row); - } else { - out.setDouble(row, function.applyAsDouble(in.getDouble(row))); + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, function.applyAsDouble(in.getDouble(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, + exception); } } }; @@ -217,18 +246,23 @@ static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withDoubleFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); } DuckDBReadableVector left = ctx.input(0); DuckDBReadableVector right = ctx.input(1); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); - } else { - out.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); + try { + if (left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setDouble(row, 
function.applyAsDouble(left.getDouble(row), right.getDouble(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, + exception); } } }; @@ -236,17 +270,22 @@ static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { static DuckDBScalarFunction longUnary(LongUnaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withLongFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); } DuckDBReadableVector in = ctx.input(0); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (in.isNull(row)) { - out.setNull(row); - } else { - out.setLong(row, function.applyAsLong(in.getLong(row))); + try { + if (in.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(in.getLong(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, + exception); } } }; @@ -254,18 +293,23 @@ static DuckDBScalarFunction longUnary(LongUnaryOperator function) { static DuckDBScalarFunction longBinary(LongBinaryOperator function) { return ctx -> { - if (!ctx.propagateNulls()) { - throw new IllegalStateException("withLongFunction requires propagateNulls(true)"); + if (!ctx.nullsPropagated()) { + throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); } DuckDBReadableVector left = ctx.input(0); DuckDBReadableVector right = ctx.input(1); DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); - } else { - out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); + try { + if 
(left.isNull(row) || right.isNull(row)) { + out.setNull(row); + } else { + out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); + } + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, + exception); } } }; @@ -279,8 +323,13 @@ static DuckDBScalarFunction nullary(Supplier function, DuckDBColumnType retur DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); for (long row = 0; row < rowCount; row++) { - Object result = typedFunction.get(); - outCodec.write(out, row, result); + try { + Object result = typedFunction.get(); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute supplier scalar function at row " + row, + exception); + } } }; } @@ -299,7 +348,7 @@ static DuckDBScalarFunction variadic(Function function, DuckDBColum DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); int vectorCount = Math.toIntExact(ctx.columnCount()); - boolean propagateNulls = ctx.propagateNulls(); + boolean propagateNulls = ctx.nullsPropagated(); DuckDBReadableVector[] vectors = new DuckDBReadableVector[vectorCount]; TypeCodec[] codecs = new TypeCodec[vectorCount]; for (int column = 0; column < vectorCount; column++) { @@ -308,25 +357,30 @@ static DuckDBScalarFunction variadic(Function function, DuckDBColum } Object[] args = new Object[vectorCount]; for (long row = 0; row < rowCount; row++) { - boolean skipRow = false; - for (int column = 0; column < vectorCount; column++) { - DuckDBReadableVector vector = vectors[column]; - if (vector.isNull(row)) { - if (propagateNulls) { - out.setNull(row); - skipRow = true; - break; + try { + boolean skipRow = false; + for (int column = 0; column < vectorCount; column++) { + DuckDBReadableVector vector = vectors[column]; + if (vector.isNull(row)) { + if (propagateNulls) { + out.setNull(row); + skipRow = true; + break; + 
} + args[column] = null; + } else { + args[column] = codecs[column].read(vector, row); } - args[column] = null; - } else { - args[column] = codecs[column].read(vector, row); } + if (skipRow) { + continue; + } + Object result = function.apply(args); + outCodec.write(out, row, result); + } catch (DuckDBFunctionException exception) { + throw new DuckDBFunctionException("Failed to execute variadic scalar function at row " + row, + exception); } - if (skipRow) { - continue; - } - Object result = function.apply(args); - outCodec.write(out, row, result); } }; } @@ -366,8 +420,12 @@ static DuckDBColumnType mapLogicalTypeToDuckDBType(DuckDBLogicalType logicalType return DuckDBColumnType.UINTEGER; case DUCKDB_TYPE_BIGINT: return DuckDBColumnType.BIGINT; + case DUCKDB_TYPE_HUGEINT: + return DuckDBColumnType.HUGEINT; case DUCKDB_TYPE_UBIGINT: return DuckDBColumnType.UBIGINT; + case DUCKDB_TYPE_UHUGEINT: + return DuckDBColumnType.UHUGEINT; case DUCKDB_TYPE_FLOAT: return DuckDBColumnType.FLOAT; case DUCKDB_TYPE_DOUBLE: @@ -455,12 +513,12 @@ private DuckDBScalarFunctionAdapter() { @FunctionalInterface private interface Reader { - T read(DuckDBReadableVector vector, long row) throws SQLException; + T read(DuckDBReadableVector vector, long row); } @FunctionalInterface private interface Writer { - void write(DuckDBWritableVector vector, long row, T value) throws SQLException; + void write(DuckDBWritableVector vector, long row, T value); } private static final class TypeCodec { @@ -478,11 +536,11 @@ boolean matches(Class declaredJavaType) { return javaType == declaredJavaType; } - Object read(DuckDBReadableVector vector, long row) throws SQLException { + Object read(DuckDBReadableVector vector, long row) { return reader.read(vector, row); } - void write(DuckDBWritableVector vector, long row, Object value) throws SQLException { + void write(DuckDBWritableVector vector, long row, Object value) { if (value == null) { vector.setNull(row); return; diff --git 
a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java index 446b7100d..6298d0713 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java @@ -32,8 +32,7 @@ public final class DuckDBScalarFunctionBuilder implements AutoCloseable { private final List> parameterJavaTypes = new ArrayList<>(); private boolean volatileFlag; private boolean specialHandlingFlag; - private boolean propagateNullsFlag = true; - private boolean callbackRequiresNullPropagation; + private boolean propagateNullsFlag; private boolean finalized; DuckDBScalarFunctionBuilder() throws SQLException { @@ -95,6 +94,17 @@ public DuckDBScalarFunctionBuilder withParameter(Class parameterType) throws return addMappedParameterType(mappedType, parameterType); } + public DuckDBScalarFunctionBuilder withParameters(Class... parameterTypes) throws SQLException { + ensureNotFinalized(); + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + for (Class parameterType : parameterTypes) { + withParameter(parameterType); + } + return this; + } + public DuckDBScalarFunctionBuilder withReturnType(DuckDBColumnType returnType) throws SQLException { ensureNotFinalized(); if (returnType == null) { @@ -119,24 +129,12 @@ public DuckDBScalarFunctionBuilder withVectorizedFunction(DuckDBScalarFunction f return setCallback(function, false); } - public DuckDBScalarFunctionBuilder propagateNulls(boolean propagateNulls) throws SQLException { - ensureNotFinalized(); - if (!propagateNulls && callbackRequiresNullPropagation) { - throw new SQLException("Primitive scalar callbacks require propagateNulls(true)"); - } - this.propagateNullsFlag = propagateNulls; - if (callback != null) { - duckdb_scalar_function_set_function(scalarFunctionRef, - new DuckDBScalarFunctionWrapper(callback, propagateNullsFlag)); - } - return this; - } - public 
DuckDBScalarFunctionBuilder withIntFunction(IntUnaryOperator function) throws SQLException { ensureNotFinalized(); if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); ensurePrimitiveCallbackCompatible("withIntFunction"); ensureUnaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); return setCallback(DuckDBScalarFunctionAdapter.intUnary(function), true); @@ -147,6 +145,7 @@ public DuckDBScalarFunctionBuilder withIntFunction(IntBinaryOperator function) t if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); ensurePrimitiveCallbackCompatible("withIntFunction"); ensureBinaryPrimitiveSignature(DuckDBColumnType.INTEGER, "withIntFunction"); return setCallback(DuckDBScalarFunctionAdapter.intBinary(function), true); @@ -157,6 +156,7 @@ public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleUnaryOperator functi if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); ensurePrimitiveCallbackCompatible("withDoubleFunction"); ensureUnaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); return setCallback(DuckDBScalarFunctionAdapter.doubleUnary(function), true); @@ -167,6 +167,7 @@ public DuckDBScalarFunctionBuilder withDoubleFunction(DoubleBinaryOperator funct if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); ensurePrimitiveCallbackCompatible("withDoubleFunction"); ensureBinaryPrimitiveSignature(DuckDBColumnType.DOUBLE, "withDoubleFunction"); return setCallback(DuckDBScalarFunctionAdapter.doubleBinary(function), true); @@ -177,6 +178,7 @@ public DuckDBScalarFunctionBuilder withLongFunction(LongUnaryOperator function) if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); 
ensurePrimitiveCallbackCompatible("withLongFunction"); ensureUnaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); return setCallback(DuckDBScalarFunctionAdapter.longUnary(function), true); @@ -187,6 +189,7 @@ public DuckDBScalarFunctionBuilder withLongFunction(LongBinaryOperator function) if (function == null) { throw new SQLException("Scalar function callback cannot be null"); } + enablePrimitiveNullPropagation(); ensurePrimitiveCallbackCompatible("withLongFunction"); ensureBinaryPrimitiveSignature(DuckDBColumnType.BIGINT, "withLongFunction"); return setCallback(DuckDBScalarFunctionAdapter.longBinary(function), true); @@ -405,21 +408,22 @@ private DuckDBScalarFunctionBuilder addMappedParameterType(DuckDBColumnType mapp private DuckDBScalarFunctionBuilder setCallback(DuckDBScalarFunction function, boolean requiresNullPropagation) throws SQLException { this.callback = function; - this.callbackRequiresNullPropagation = requiresNullPropagation; + this.propagateNullsFlag = requiresNullPropagation; duckdb_scalar_function_set_function(scalarFunctionRef, new DuckDBScalarFunctionWrapper(function, propagateNullsFlag)); return this; } private void ensurePrimitiveCallbackCompatible(String callbackMethodName) throws SQLException { - if (!propagateNullsFlag) { - throw new SQLException(callbackMethodName + " requires propagateNulls(true)"); - } if (varArgType != null) { throw new SQLException(callbackMethodName + " does not support varargs; use withVarArgsFunction instead"); } } + private void enablePrimitiveNullPropagation() { + propagateNullsFlag = true; + } + private void ensureUnaryPrimitiveSignature(DuckDBColumnType expectedType, String callbackMethodName) throws SQLException { if (parameterTypes.size() != 1) { diff --git a/src/main/java/org/duckdb/DuckDBScalarRow.java b/src/main/java/org/duckdb/DuckDBScalarRow.java index cd0e10721..5033faaba 100644 --- a/src/main/java/org/duckdb/DuckDBScalarRow.java +++ b/src/main/java/org/duckdb/DuckDBScalarRow.java @@ 
-2,7 +2,6 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.sql.SQLException; import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalDateTime; @@ -28,7 +27,15 @@ public boolean isNull(int columnIndex) { public boolean getBoolean(int columnIndex) { try { return input(columnIndex).getBoolean(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("BOOLEAN", columnIndex, exception); + } + } + + public boolean getBoolean(int columnIndex, boolean defaultVal) { + try { + return input(columnIndex).getBoolean(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("BOOLEAN", columnIndex, exception); } } @@ -36,7 +43,15 @@ public boolean getBoolean(int columnIndex) { public byte getByte(int columnIndex) { try { return input(columnIndex).getByte(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("TINYINT", columnIndex, exception); + } + } + + public byte getByte(int columnIndex, byte defaultVal) { + try { + return input(columnIndex).getByte(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("TINYINT", columnIndex, exception); } } @@ -44,7 +59,15 @@ public byte getByte(int columnIndex) { public short getShort(int columnIndex) { try { return input(columnIndex).getShort(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("SMALLINT", columnIndex, exception); + } + } + + public short getShort(int columnIndex, short defaultVal) { + try { + return input(columnIndex).getShort(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("SMALLINT", columnIndex, exception); } } @@ -52,7 +75,15 @@ public short getShort(int columnIndex) { public short getUint8(int columnIndex) { try { return input(columnIndex).getUint8(rowIndex); - } catch (SQLException 
exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("UTINYINT", columnIndex, exception); + } + } + + public short getUint8(int columnIndex, short defaultVal) { + try { + return input(columnIndex).getUint8(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("UTINYINT", columnIndex, exception); } } @@ -60,7 +91,15 @@ public short getUint8(int columnIndex) { public int getUint16(int columnIndex) { try { return input(columnIndex).getUint16(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("USMALLINT", columnIndex, exception); + } + } + + public int getUint16(int columnIndex, int defaultVal) { + try { + return input(columnIndex).getUint16(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("USMALLINT", columnIndex, exception); } } @@ -68,7 +107,15 @@ public int getUint16(int columnIndex) { public int getInt(int columnIndex) { try { return input(columnIndex).getInt(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("INTEGER", columnIndex, exception); + } + } + + public int getInt(int columnIndex, int defaultVal) { + try { + return input(columnIndex).getInt(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("INTEGER", columnIndex, exception); } } @@ -76,7 +123,15 @@ public int getInt(int columnIndex) { public long getUint32(int columnIndex) { try { return input(columnIndex).getUint32(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("UINTEGER", columnIndex, exception); + } + } + + public long getUint32(int columnIndex, long defaultVal) { + try { + return input(columnIndex).getUint32(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("UINTEGER", columnIndex, exception); } } @@ -84,15 +139,39 @@ public 
long getUint32(int columnIndex) { public long getLong(int columnIndex) { try { return input(columnIndex).getLong(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("BIGINT", columnIndex, exception); } } + public long getLong(int columnIndex, long defaultVal) { + try { + return input(columnIndex).getLong(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { + throw readFailure("BIGINT", columnIndex, exception); + } + } + + public BigInteger getHugeInt(int columnIndex) { + try { + return input(columnIndex).getHugeInt(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("HUGEINT", columnIndex, exception); + } + } + + public BigInteger getUHugeInt(int columnIndex) { + try { + return input(columnIndex).getUHugeInt(rowIndex); + } catch (DuckDBFunctionException exception) { + throw readFailure("UHUGEINT", columnIndex, exception); + } + } + public BigInteger getUint64(int columnIndex) { try { return input(columnIndex).getUint64(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("UBIGINT", columnIndex, exception); } } @@ -100,7 +179,15 @@ public BigInteger getUint64(int columnIndex) { public float getFloat(int columnIndex) { try { return input(columnIndex).getFloat(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw readFailure("FLOAT", columnIndex, exception); + } + } + + public float getFloat(int columnIndex, float defaultVal) { + try { + return input(columnIndex).getFloat(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("FLOAT", columnIndex, exception); } } @@ -108,7 +195,15 @@ public float getFloat(int columnIndex) { public double getDouble(int columnIndex) { try { return input(columnIndex).getDouble(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { + throw 
readFailure("DOUBLE", columnIndex, exception); + } + } + + public double getDouble(int columnIndex, double defaultVal) { + try { + return input(columnIndex).getDouble(rowIndex, defaultVal); + } catch (DuckDBFunctionException exception) { throw readFailure("DOUBLE", columnIndex, exception); } } @@ -116,7 +211,7 @@ public double getDouble(int columnIndex) { public LocalDate getLocalDate(int columnIndex) { try { return input(columnIndex).getLocalDate(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("DATE", columnIndex, exception); } } @@ -124,7 +219,7 @@ public LocalDate getLocalDate(int columnIndex) { public java.sql.Date getDate(int columnIndex) { try { return input(columnIndex).getDate(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("DATE", columnIndex, exception); } } @@ -132,7 +227,7 @@ public java.sql.Date getDate(int columnIndex) { public LocalDateTime getLocalDateTime(int columnIndex) { try { return input(columnIndex).getLocalDateTime(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("TIMESTAMP", columnIndex, exception); } } @@ -140,7 +235,7 @@ public LocalDateTime getLocalDateTime(int columnIndex) { public Timestamp getTimestamp(int columnIndex) { try { return input(columnIndex).getTimestamp(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("TIMESTAMP", columnIndex, exception); } } @@ -148,7 +243,7 @@ public Timestamp getTimestamp(int columnIndex) { public OffsetDateTime getOffsetDateTime(int columnIndex) { try { return input(columnIndex).getOffsetDateTime(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("TIMESTAMP WITH TIME ZONE", columnIndex, exception); } } @@ -156,7 +251,7 @@ public OffsetDateTime getOffsetDateTime(int columnIndex) { 
public BigDecimal getBigDecimal(int columnIndex) { try { return input(columnIndex).getBigDecimal(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("DECIMAL", columnIndex, exception); } } @@ -164,7 +259,7 @@ public BigDecimal getBigDecimal(int columnIndex) { public String getString(int columnIndex) { try { return input(columnIndex).getString(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw readFailure("VARCHAR", columnIndex, exception); } } @@ -172,7 +267,7 @@ public String getString(int columnIndex) { public void setNull() { try { context.output().setNull(rowIndex); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("NULL", exception); } } @@ -180,7 +275,7 @@ public void setNull() { public void setBoolean(boolean value) { try { context.output().setBoolean(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("BOOLEAN", exception); } } @@ -188,7 +283,7 @@ public void setBoolean(boolean value) { public void setByte(byte value) { try { context.output().setByte(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TINYINT", exception); } } @@ -196,7 +291,7 @@ public void setByte(byte value) { public void setShort(short value) { try { context.output().setShort(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("SMALLINT", exception); } } @@ -204,7 +299,7 @@ public void setShort(short value) { public void setUint8(int value) { try { context.output().setUint8(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("UTINYINT", exception); } } @@ -212,7 +307,7 @@ public void setUint8(int value) { public void setUint16(int value) { try { 
context.output().setUint16(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("USMALLINT", exception); } } @@ -220,7 +315,7 @@ public void setUint16(int value) { public void setInt(int value) { try { context.output().setInt(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("INTEGER", exception); } } @@ -228,7 +323,7 @@ public void setInt(int value) { public void setUint32(long value) { try { context.output().setUint32(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("UINTEGER", exception); } } @@ -236,15 +331,31 @@ public void setUint32(long value) { public void setLong(long value) { try { context.output().setLong(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("BIGINT", exception); } } + public void setHugeInt(BigInteger value) { + try { + context.output().setHugeInt(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("HUGEINT", exception); + } + } + + public void setUHugeInt(BigInteger value) { + try { + context.output().setUHugeInt(rowIndex, value); + } catch (DuckDBFunctionException exception) { + throw writeFailure("UHUGEINT", exception); + } + } + public void setUint64(BigInteger value) { try { context.output().setUint64(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("UBIGINT", exception); } } @@ -252,7 +363,7 @@ public void setUint64(BigInteger value) { public void setFloat(float value) { try { context.output().setFloat(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("FLOAT", exception); } } @@ -260,7 +371,7 @@ public void setFloat(float value) { public void setDouble(double value) { try { 
context.output().setDouble(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("DOUBLE", exception); } } @@ -268,7 +379,7 @@ public void setDouble(double value) { public void setDate(LocalDate value) { try { context.output().setDate(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("DATE", exception); } } @@ -276,7 +387,7 @@ public void setDate(LocalDate value) { public void setDate(java.sql.Date value) { try { context.output().setDate(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("DATE", exception); } } @@ -284,7 +395,7 @@ public void setDate(java.sql.Date value) { public void setDate(java.util.Date value) { try { context.output().setDate(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("DATE", exception); } } @@ -292,7 +403,7 @@ public void setDate(java.util.Date value) { public void setTimestamp(LocalDateTime value) { try { context.output().setTimestamp(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TIMESTAMP", exception); } } @@ -300,7 +411,7 @@ public void setTimestamp(LocalDateTime value) { public void setTimestamp(Timestamp value) { try { context.output().setTimestamp(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TIMESTAMP", exception); } } @@ -308,7 +419,7 @@ public void setTimestamp(Timestamp value) { public void setTimestamp(java.util.Date value) { try { context.output().setTimestamp(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TIMESTAMP", exception); } } @@ -316,7 +427,7 @@ public void setTimestamp(java.util.Date value) { public void 
setTimestamp(LocalDate value) { try { context.output().setTimestamp(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TIMESTAMP", exception); } } @@ -324,7 +435,7 @@ public void setTimestamp(LocalDate value) { public void setOffsetDateTime(OffsetDateTime value) { try { context.output().setOffsetDateTime(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("TIMESTAMP WITH TIME ZONE", exception); } } @@ -332,7 +443,7 @@ public void setOffsetDateTime(OffsetDateTime value) { public void setBigDecimal(BigDecimal value) { try { context.output().setBigDecimal(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("DECIMAL", exception); } } @@ -340,7 +451,7 @@ public void setBigDecimal(BigDecimal value) { public void setString(String value) { try { context.output().setString(rowIndex, value); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { throw writeFailure("VARCHAR", exception); } } @@ -349,12 +460,13 @@ private DuckDBReadableVector input(int columnIndex) { return context.inputUnchecked(columnIndex); } - private IllegalStateException readFailure(String type, int columnIndex, SQLException exception) { - return new IllegalStateException( + private DuckDBFunctionException readFailure(String type, int columnIndex, DuckDBFunctionException exception) { + return new DuckDBFunctionException( "Failed to read " + type + " from input column " + columnIndex + " at row " + rowIndex, exception); } - private IllegalStateException writeFailure(String type, SQLException exception) { - return new IllegalStateException("Failed to write " + type + " to output row " + rowIndex, exception); + private DuckDBFunctionException writeFailure(String type, DuckDBFunctionException exception) { + return new DuckDBFunctionException("Failed to write " + type + " 
to output row " + rowIndex, exception); } + } diff --git a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java index 909abd6fd..a7768cc96 100644 --- a/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java +++ b/src/main/java/org/duckdb/DuckDBVectorTypeInfo.java @@ -48,8 +48,12 @@ static DuckDBVectorTypeInfo fromVector(ByteBuffer vectorRef) throws SQLException return new DuckDBVectorTypeInfo(DuckDBColumnType.UINTEGER, capiType, capiType, 4, null); case DUCKDB_TYPE_BIGINT: return new DuckDBVectorTypeInfo(DuckDBColumnType.BIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_HUGEINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.HUGEINT, capiType, capiType, 16, null); case DUCKDB_TYPE_UBIGINT: return new DuckDBVectorTypeInfo(DuckDBColumnType.UBIGINT, capiType, capiType, 8, null); + case DUCKDB_TYPE_UHUGEINT: + return new DuckDBVectorTypeInfo(DuckDBColumnType.UHUGEINT, capiType, capiType, 16, null); case DUCKDB_TYPE_FLOAT: return new DuckDBVectorTypeInfo(DuckDBColumnType.FLOAT, capiType, capiType, 4, null); case DUCKDB_TYPE_DOUBLE: diff --git a/src/main/java/org/duckdb/DuckDBWritableVector.java b/src/main/java/org/duckdb/DuckDBWritableVector.java index d1532a976..1179e9084 100644 --- a/src/main/java/org/duckdb/DuckDBWritableVector.java +++ b/src/main/java/org/duckdb/DuckDBWritableVector.java @@ -2,102 +2,115 @@ import java.math.BigDecimal; import java.math.BigInteger; -import java.sql.SQLException; import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; -public interface DuckDBWritableVector { - DuckDBColumnType getType(); +/** + * Mutable scalar callback view over a DuckDB output vector. + * + *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value write + * errors. Invalid row indexes throw {@link IndexOutOfBoundsException}. + */ +public abstract class DuckDBWritableVector { + public abstract DuckDBColumnType getType(); - long rowCount(); + public abstract long rowCount(); - void addNull() throws SQLException; + public abstract void addNull(); - void setNull(long row) throws SQLException; + public abstract void setNull(long row); - void addBoolean(boolean value) throws SQLException; + public abstract void addBoolean(boolean value); - void setBoolean(long row, boolean value) throws SQLException; + public abstract void setBoolean(long row, boolean value); - void addByte(byte value) throws SQLException; + public abstract void addByte(byte value); - void setByte(long row, byte value) throws SQLException; + public abstract void setByte(long row, byte value); - void addShort(short value) throws SQLException; + public abstract void addShort(short value); - void setShort(long row, short value) throws SQLException; + public abstract void setShort(long row, short value); - void addUint8(int value) throws SQLException; + public abstract void addUint8(int value); - void setUint8(long row, int value) throws SQLException; + public abstract void setUint8(long row, int value); - void addUint16(int value) throws SQLException; + public abstract void addUint16(int value); - void setUint16(long row, int value) throws SQLException; + public abstract void setUint16(long row, int value); - void addInt(int value) throws SQLException; + public abstract void addInt(int value); - void setInt(long row, int value) throws SQLException; + public abstract void setInt(long row, int value); - void addUint32(long value) throws SQLException; + public abstract void addUint32(long value); - void setUint32(long row, long value) throws SQLException; + public abstract void setUint32(long row, long value); - void addLong(long value) throws SQLException; + public 
abstract void addLong(long value); - void setLong(long row, long value) throws SQLException; + public abstract void setLong(long row, long value); - void addUint64(BigInteger value) throws SQLException; + public abstract void addHugeInt(BigInteger value); - void setUint64(long row, BigInteger value) throws SQLException; + public abstract void setHugeInt(long row, BigInteger value); - void addFloat(float value) throws SQLException; + public abstract void addUHugeInt(BigInteger value); - void setFloat(long row, float value) throws SQLException; + public abstract void setUHugeInt(long row, BigInteger value); - void addDouble(double value) throws SQLException; + public abstract void addUint64(BigInteger value); - void setDouble(long row, double value) throws SQLException; + public abstract void setUint64(long row, BigInteger value); - void addDate(LocalDate value) throws SQLException; + public abstract void addFloat(float value); - void setDate(long row, LocalDate value) throws SQLException; + public abstract void setFloat(long row, float value); - void addDate(java.sql.Date value) throws SQLException; + public abstract void addDouble(double value); - void setDate(long row, java.sql.Date value) throws SQLException; + public abstract void setDouble(long row, double value); - void addDate(java.util.Date value) throws SQLException; + public abstract void addDate(LocalDate value); - void setDate(long row, java.util.Date value) throws SQLException; + public abstract void setDate(long row, LocalDate value); - void addTimestamp(LocalDateTime value) throws SQLException; + public abstract void addDate(java.sql.Date value); - void setTimestamp(long row, LocalDateTime value) throws SQLException; + public abstract void setDate(long row, java.sql.Date value); - void addTimestamp(Timestamp value) throws SQLException; + public abstract void addDate(java.util.Date value); - void setTimestamp(long row, Timestamp value) throws SQLException; + public abstract void setDate(long row, 
java.util.Date value); - void addTimestamp(java.util.Date value) throws SQLException; + public abstract void addTimestamp(LocalDateTime value); - void setTimestamp(long row, java.util.Date value) throws SQLException; + public abstract void setTimestamp(long row, LocalDateTime value); - void addTimestamp(LocalDate value) throws SQLException; + public abstract void addTimestamp(Timestamp value); - void setTimestamp(long row, LocalDate value) throws SQLException; + public abstract void setTimestamp(long row, Timestamp value); - void addOffsetDateTime(OffsetDateTime value) throws SQLException; + public abstract void addTimestamp(java.util.Date value); - void setOffsetDateTime(long row, OffsetDateTime value) throws SQLException; + public abstract void setTimestamp(long row, java.util.Date value); - void addBigDecimal(BigDecimal value) throws SQLException; + public abstract void addTimestamp(LocalDate value); - void setBigDecimal(long row, BigDecimal value) throws SQLException; + public abstract void setTimestamp(long row, LocalDate value); - void addString(String value) throws SQLException; + public abstract void addOffsetDateTime(OffsetDateTime value); - void setString(long row, String value) throws SQLException; + public abstract void setOffsetDateTime(long row, OffsetDateTime value); + + public abstract void addBigDecimal(BigDecimal value); + + public abstract void setBigDecimal(long row, BigDecimal value); + + public abstract void addString(String value); + + public abstract void setString(long row, String value); } diff --git a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java index 4b5a68e47..1b7e8cb13 100644 --- a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java +++ b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java @@ -7,8 +7,6 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.LongBuffer; -import java.sql.SQLException; import 
java.sql.Timestamp; import java.time.Instant; import java.time.LocalDate; @@ -17,7 +15,7 @@ import java.time.ZoneId; import java.time.ZoneOffset; -final class DuckDBWritableVectorImpl implements DuckDBWritableVector { +final class DuckDBWritableVectorImpl extends DuckDBWritableVector { private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); @@ -28,15 +26,20 @@ final class DuckDBWritableVectorImpl implements DuckDBWritableVector { private ByteBuffer validity; private long appendIndex; - DuckDBWritableVectorImpl(ByteBuffer vectorRef, long rowCount) throws SQLException { + DuckDBWritableVectorImpl(ByteBuffer vectorRef, long rowCount) { if (vectorRef == null) { - throw new SQLException("Invalid vector reference"); + throw new DuckDBFunctionException("Invalid vector reference"); } this.vectorRef = vectorRef; this.rowCount = rowCount; - this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); - this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)); - this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new DuckDBFunctionException("Failed to resolve vector type info", exception); + } + this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); + this.validity = validityBuffer == null ? 
null : validityBuffer.order(NATIVE_ORDER); } @Override @@ -50,12 +53,12 @@ public long rowCount() { } @Override - public void addNull() throws SQLException { + public void addNull() { setNull(nextAppendRow()); } @Override - public void setNull(long row) throws SQLException { + public void setNull(long row) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); @@ -66,247 +69,306 @@ public void setNull(long row) throws SQLException { } @Override - public void addBoolean(boolean value) throws SQLException { + public void addBoolean(boolean value) { setBoolean(nextAppendRow(), value); } @Override - public void setBoolean(long row, boolean value) throws SQLException { + public void setBoolean(long row, boolean value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.BOOLEAN); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } data.put(checkedRowIndex(row), value ? 
(byte) 1 : (byte) 0); markValid(row); } @Override - public void addByte(byte value) throws SQLException { + public void addByte(byte value) { setByte(nextAppendRow(), value); } @Override - public void setByte(long row, byte value) throws SQLException { + public void setByte(long row, byte value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.TINYINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } data.put(checkedRowIndex(row), value); markValid(row); } @Override - public void addShort(short value) throws SQLException { + public void addShort(short value) { setShort(nextAppendRow(), value); } @Override - public void setShort(long row, short value) throws SQLException { + public void setShort(long row, short value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.SMALLINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } - data.order(NATIVE_ORDER).putShort(checkedByteOffset(row, Short.BYTES), value); + data.putShort(checkedByteOffset(row, Short.BYTES), value); markValid(row); } @Override - public void addUint8(int value) throws SQLException { + public void addUint8(int value) { setUint8(nextAppendRow(), value); } @Override - public void setUint8(long row, int value) throws SQLException { + public void setUint8(long row, int value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.UTINYINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } String rangeError = unsignedRangeErrorMessage("UTINYINT", value, 
0xFFL); if (rangeError != null) { - throw new SQLException(rangeError); + throw new DuckDBFunctionException(rangeError); } data.put(checkedRowIndex(row), (byte) value); markValid(row); } @Override - public void addUint16(int value) throws SQLException { + public void addUint16(int value) { setUint16(nextAppendRow(), value); } @Override - public void setUint16(long row, int value) throws SQLException { + public void setUint16(long row, int value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.USMALLINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } String rangeError = unsignedRangeErrorMessage("USMALLINT", value, 0xFFFFL); if (rangeError != null) { - throw new SQLException(rangeError); + throw new DuckDBFunctionException(rangeError); } - data.order(NATIVE_ORDER).putShort(checkedByteOffset(row, Short.BYTES), (short) value); + data.putShort(checkedByteOffset(row, Short.BYTES), (short) value); markValid(row); } @Override - public void addInt(int value) throws SQLException { + public void addInt(int value) { setInt(nextAppendRow(), value); } @Override - public void setInt(long row, int value) throws SQLException { + public void setInt(long row, int value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.INTEGER); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } - data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), value); + data.putInt(checkedByteOffset(row, Integer.BYTES), value); markValid(row); } @Override - public void addUint32(long value) throws SQLException { + public void addUint32(long value) { setUint32(nextAppendRow(), value); } @Override - public void 
setUint32(long row, long value) throws SQLException { + public void setUint32(long row, long value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.UINTEGER); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } String rangeError = unsignedRangeErrorMessage("UINTEGER", value, 0xFFFFFFFFL); if (rangeError != null) { - throw new SQLException(rangeError); + throw new DuckDBFunctionException(rangeError); } - data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), (int) value); + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) value); markValid(row); } @Override - public void addLong(long value) throws SQLException { + public void addLong(long value) { setLong(nextAppendRow(), value); } @Override - public void setLong(long row, long value) throws SQLException { + public void setLong(long row, long value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.BIGINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); + } + data.putLong(checkedByteOffset(row, Long.BYTES), value); + markValid(row); + } + + @Override + public void addHugeInt(BigInteger value) { + setHugeInt(nextAppendRow(), value); + } + + @Override + public void setHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.HUGEINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + DuckDBHugeInt hugeInt; + try { + hugeInt = new DuckDBHugeInt(value); + } catch 
(java.sql.SQLException exception) { + throw new DuckDBFunctionException("Value out of range for HUGEINT: " + value, exception); + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, hugeInt.lower()); + data.putLong(offset + Long.BYTES, hugeInt.upper()); + markValid(row); + } + + @Override + public void addUHugeInt(BigInteger value) { + setUHugeInt(nextAppendRow(), value); + } + + @Override + public void setUHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UHUGEINT); + if (typeError != null) { + throw new DuckDBFunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(DuckDBHugeInt.UHUGE_INT_MAX) > 0) { + throw new DuckDBFunctionException("Value out of range for UHUGEINT: " + value); } - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), value); + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, value.longValue()); + data.putLong(offset + Long.BYTES, value.shiftRight(Long.SIZE).longValue()); markValid(row); } @Override - public void addUint64(BigInteger value) throws SQLException { + public void addUint64(BigInteger value) { setUint64(nextAppendRow(), value); } @Override - public void setUint64(long row, BigInteger value) throws SQLException { + public void setUint64(long row, BigInteger value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.UBIGINT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); return; } if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { - throw new SQLException("Value out of range for 
UBIGINT: " + value); + throw new DuckDBFunctionException("Value out of range for UBIGINT: " + value); } - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); + data.putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); markValid(row); } @Override - public void addFloat(float value) throws SQLException { + public void addFloat(float value) { setFloat(nextAppendRow(), value); } @Override - public void setFloat(long row, float value) throws SQLException { + public void setFloat(long row, float value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.FLOAT); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } - data.order(NATIVE_ORDER).putFloat(checkedByteOffset(row, Float.BYTES), value); + data.putFloat(checkedByteOffset(row, Float.BYTES), value); markValid(row); } @Override - public void addDouble(double value) throws SQLException { + public void addDouble(double value) { setDouble(nextAppendRow(), value); } @Override - public void setDouble(long row, double value) throws SQLException { + public void setDouble(long row, double value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.DOUBLE); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } - data.order(NATIVE_ORDER).putDouble(checkedByteOffset(row, Double.BYTES), value); + data.putDouble(checkedByteOffset(row, Double.BYTES), value); markValid(row); } @Override - public void addDate(LocalDate value) throws SQLException { + public void addDate(LocalDate value) { setDate(nextAppendRow(), value); } @Override - public void setDate(long row, LocalDate value) throws SQLException { + public void setDate(long 
row, LocalDate value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.DATE); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); @@ -314,29 +376,29 @@ public void setDate(long row, LocalDate value) throws SQLException { } long days = value.toEpochDay(); if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { - throw new SQLException("Value out of range for DATE: " + value); + throw new DuckDBFunctionException("Value out of range for DATE: " + value); } - data.order(NATIVE_ORDER).putInt(checkedByteOffset(row, Integer.BYTES), (int) days); + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) days); markValid(row); } @Override - public void addDate(java.sql.Date value) throws SQLException { + public void addDate(java.sql.Date value) { setDate(nextAppendRow(), value); } @Override - public void setDate(long row, java.sql.Date value) throws SQLException { + public void setDate(long row, java.sql.Date value) { setDate(row, value == null ? 
null : value.toLocalDate()); } @Override - public void addDate(java.util.Date value) throws SQLException { + public void addDate(java.util.Date value) { setDate(nextAppendRow(), value); } @Override - public void setDate(long row, java.util.Date value) throws SQLException { + public void setDate(long row, java.util.Date value) { if (value == null) { setNull(row); return; @@ -350,35 +412,35 @@ public void setDate(long row, java.util.Date value) throws SQLException { } @Override - public void addTimestamp(LocalDateTime value) throws SQLException { + public void addTimestamp(LocalDateTime value) { setTimestamp(nextAppendRow(), value); } @Override - public void setTimestamp(long row, LocalDateTime value) throws SQLException { + public void setTimestamp(long row, LocalDateTime value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = timestampTypeMismatchMessage(false); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); return; } - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); + data.putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); markValid(row); } @Override - public void addTimestamp(Timestamp value) throws SQLException { + public void addTimestamp(Timestamp value) { setTimestamp(nextAppendRow(), value); } @Override - public void setTimestamp(long row, Timestamp value) throws SQLException { + public void setTimestamp(long row, Timestamp value) { if (value == null) { setNull(row); return; @@ -388,7 +450,7 @@ public void setTimestamp(long row, Timestamp value) throws SQLException { if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(value.toInstant())); + data.putLong(checkedByteOffset(row, Long.BYTES), 
encodeInstant(value.toInstant())); markValid(row); return; } @@ -396,19 +458,19 @@ public void setTimestamp(long row, Timestamp value) throws SQLException { } @Override - public void addTimestamp(java.util.Date value) throws SQLException { + public void addTimestamp(java.util.Date value) { setTimestamp(nextAppendRow(), value); } @Override - public void setTimestamp(long row, java.util.Date value) throws SQLException { + public void setTimestamp(long row, java.util.Date value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = timestampTypeMismatchMessage(false); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); @@ -418,17 +480,17 @@ public void setTimestamp(long row, java.util.Date value) throws SQLException { setTimestamp(row, (Timestamp) value); return; } - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); + data.putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); markValid(row); } @Override - public void addTimestamp(LocalDate value) throws SQLException { + public void addTimestamp(LocalDate value) { setTimestamp(nextAppendRow(), value); } @Override - public void setTimestamp(long row, LocalDate value) throws SQLException { + public void setTimestamp(long row, LocalDate value) { if (value == null) { setNull(row); return; @@ -439,7 +501,7 @@ public void setTimestamp(long row, LocalDate value) throws SQLException { throw new IndexOutOfBoundsException(rowError); } Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); - data.order(NATIVE_ORDER).putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); + data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); markValid(row); return; } @@ -447,45 +509,43 @@ public void setTimestamp(long row, LocalDate value) throws 
SQLException { } @Override - public void addOffsetDateTime(OffsetDateTime value) throws SQLException { + public void addOffsetDateTime(OffsetDateTime value) { setOffsetDateTime(nextAppendRow(), value); } @Override - public void setOffsetDateTime(long row, OffsetDateTime value) throws SQLException { + public void setOffsetDateTime(long row, OffsetDateTime value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = timestampTypeMismatchMessage(true); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); return; } - data.order(NATIVE_ORDER) - .putLong( - checkedByteOffset(row, Long.BYTES), - DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); + data.putLong(checkedByteOffset(row, Long.BYTES), + DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); markValid(row); } @Override - public void addBigDecimal(BigDecimal value) throws SQLException { + public void addBigDecimal(BigDecimal value) { setBigDecimal(nextAppendRow(), value); } @Override - public void setBigDecimal(long row, BigDecimal value) throws SQLException { + public void setBigDecimal(long row, BigDecimal value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.DECIMAL); if (typeError != null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); @@ -503,56 +563,52 @@ public void setBigDecimal(long row, BigDecimal value) throws SQLException { switch (typeInfo.storageType) { case DUCKDB_TYPE_SMALLINT: try { - data.order(NATIVE_ORDER) - .putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); + 
data.putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); } catch (ArithmeticException e) { throw decimalOutOfRange(value, e); } break; case DUCKDB_TYPE_INTEGER: try { - data.order(NATIVE_ORDER) - .putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); + data.putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); } catch (ArithmeticException e) { throw decimalOutOfRange(value, e); } break; case DUCKDB_TYPE_BIGINT: try { - data.order(NATIVE_ORDER) - .putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); + data.putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); } catch (ArithmeticException e) { throw decimalOutOfRange(value, e); } break; case DUCKDB_TYPE_HUGEINT: { BigInteger unscaled = scaled.unscaledValue(); - ByteBuffer slice = data.duplicate().order(NATIVE_ORDER); - slice.position(checkedByteOffset(row, typeInfo.widthBytes)); - slice.putLong(unscaled.longValue()); - slice.putLong(unscaled.shiftRight(Long.SIZE).longValue()); + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, unscaled.longValue()); + data.putLong(offset + Long.BYTES, unscaled.shiftRight(Long.SIZE).longValue()); break; } default: - throw new SQLException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); } markValid(row); } @Override - public void addString(String value) throws SQLException { + public void addString(String value) { setString(nextAppendRow(), value); } @Override - public void setString(long row, String value) throws SQLException { + public void setString(long row, String value) { String rowError = rowIndexErrorMessage(row); if (rowError != null) { throw new IndexOutOfBoundsException(rowError); } String typeError = typeMismatchMessage(DuckDBColumnType.VARCHAR); if (typeError != 
null) { - throw new SQLException(typeError); + throw new DuckDBFunctionException(typeError); } if (value == null) { setNull(row); @@ -566,15 +622,16 @@ ByteBuffer vectorRef() { return vectorRef; } - private void ensureValidity() throws SQLException { + private void ensureValidity() { if (validity != null) { return; } duckdb_vector_ensure_validity_writable(vectorRef); validity = duckdb_vector_get_validity(vectorRef, rowCount); if (validity == null) { - throw new SQLException("Cannot initialize vector validity"); + throw new DuckDBFunctionException("Cannot initialize vector validity"); } + validity = validity.order(NATIVE_ORDER); } private void markValid(long row) { @@ -587,17 +644,16 @@ private void markValid(long row) { } private void setRowValidity(long row, boolean valid) { - LongBuffer entries = validity.asLongBuffer(); - int entryIndex = Math.toIntExact(row / Long.SIZE); + int entryOffset = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); long bitIndex = row % Long.SIZE; long mask = 1L << bitIndex; - long entry = entries.get(entryIndex); + long entry = validity.getLong(entryOffset); if (valid) { entry |= mask; } else { entry &= ~mask; } - entries.put(entryIndex, entry); + validity.putLong(entryOffset, entry); } private String typeMismatchMessage(DuckDBColumnType expected) { @@ -633,7 +689,7 @@ private String timestampTypeMismatchMessage(boolean requireTimezone) { } } - private long encodeLocalDateTime(LocalDateTime value) throws SQLException { + private long encodeLocalDateTime(LocalDateTime value) { Instant instant; if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { instant = value.atZone(ZoneId.systemDefault()).toInstant(); @@ -643,11 +699,11 @@ private long encodeLocalDateTime(LocalDateTime value) throws SQLException { return encodeInstant(instant); } - private long encodeJavaUtilDate(java.util.Date value) throws SQLException { + private long encodeJavaUtilDate(java.util.Date value) { return 
encodeInstant(Instant.ofEpochMilli(value.getTime())); } - private long encodeInstant(Instant instant) throws SQLException { + private long encodeInstant(Instant instant) { long epochSeconds = instant.getEpochSecond(); int nano = instant.getNano(); switch (typeInfo.capiType) { @@ -661,7 +717,7 @@ private long encodeInstant(Instant instant) throws SQLException { case DUCKDB_TYPE_TIMESTAMP_NS: return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); default: - throw new SQLException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); } } @@ -672,12 +728,12 @@ private static String unsignedRangeErrorMessage(String typeName, long value, lon return null; } - private SQLException decimalOutOfRange(BigDecimal value) { - return new SQLException("Value out of range for " + decimalTypeName() + ": " + value); + private DuckDBFunctionException decimalOutOfRange(BigDecimal value) { + return new DuckDBFunctionException("Value out of range for " + decimalTypeName() + ": " + value); } - private SQLException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { - SQLException exception = decimalOutOfRange(value); + private DuckDBFunctionException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { + DuckDBFunctionException exception = decimalOutOfRange(value); exception.initCause(cause); return exception; } diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index a947304e9..a0f4f1278 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -35,11 +35,7 @@ public static void test_bindings_vector_row_index_stream() throws Exception { DuckDBReadableVector readable = new DuckDBReadableVectorImpl(inputVec, 3); DuckDBWritableVector output = new DuckDBWritableVectorImpl(outputVec, 3); readable.rowIndexStream().forEachOrdered(row -> 
{ - try { - output.setInt(row, readable.getInt(row) + 1); - } catch (SQLException exception) { - throw new RuntimeException(exception); - } + output.setInt(row, readable.getInt(row) + 1); }); DuckDBReadableVector result = new DuckDBReadableVectorImpl(outputVec, 3); @@ -75,7 +71,7 @@ public static void test_bindings_writable_vector_failed_append_does_not_advance( ByteBuffer vec = duckdb_create_vector(lt); DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 2); - assertThrows(() -> { writable.addString("boom"); }, SQLException.class); + assertThrows(() -> { writable.addString("boom"); }, DuckDBFunctionException.class); writable.addInt(7); writable.addInt(8); @@ -242,7 +238,7 @@ public static void test_bindings_writable_vector_stack_trace_origin() throws Exc try { writable.setInt(0, 42); fail("Expected setInt to reject VARCHAR vector"); - } catch (SQLException exception) { + } catch (DuckDBFunctionException exception) { assertTrue(exception.getMessage().contains("Expected vector type INTEGER, found VARCHAR")); assertEquals(exception.getStackTrace()[0].getMethodName(), "setInt"); } @@ -289,6 +285,38 @@ public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws duckdb_destroy_logical_type(lt); } + public static void test_bindings_vector_uhugeint_native_endian_roundtrip() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_UHUGEINT.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + + int rowCount = (int) duckdb_vector_size(); + assertTrue(rowCount >= 4); + BigInteger[] values = + new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), + new BigInteger("340282366920938463463374607431768211455")}; + + DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + writable.setUHugeInt(i, values[i]); + } + + ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Long.BYTES * 2); + ByteBuffer 
nativeData = rawData.order(ByteOrder.nativeOrder()); + for (int i = 0; i < values.length; i++) { + int offset = i * Long.BYTES * 2; + assertEquals(nativeData.getLong(offset), values[i].longValue()); + assertEquals(nativeData.getLong(offset + Long.BYTES), values[i].shiftRight(Long.SIZE).longValue()); + } + + DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + for (int i = 0; i < values.length; i++) { + assertEquals(readable.getUHugeInt(i), values[i]); + } + + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + public static void test_bindings_vector_validity() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index e730b1012..a9c3afbc8 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -80,9 +80,9 @@ public static void test_register_scalar_function_builder() throws Exception { .withName("java_add_int_builder") .withParameter(intType) .withReturnType(intType) - .propagateNulls(true) - .withVectorizedFunction( - ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); + }) .register(conn); assertEquals(function.name(), "java_add_int_builder"); assertEquals(function.parameterTypes().size(), 1); @@ -91,7 +91,7 @@ public static void test_register_scalar_function_builder() throws Exception { assertEquals(function.varArgType(), null); assertEquals(function.isVolatile(), false); assertEquals(function.hasSpecialHandling(), false); - assertEquals(function.propagateNulls(), true); + assertEquals(function.propagateNulls(), false); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_builder(v) FROM (VALUES (1), (NULL), 
(41)) t(v)")) { @@ -112,7 +112,7 @@ public static void test_register_scalar_function_builder_connection_without_unwr .withName("java_add_int_connection") .withParameter(Integer.class) .withReturnType(Integer.class) - .withFunction((Integer x) -> x + 1) + .withFunction((Integer x) -> null != x ? x + 1 : null) .register(conn); try (ResultSet rs = @@ -135,7 +135,7 @@ public static void test_register_scalar_function_builder_returns_detached_metada function = builder.withName("java_add_int_detached") .withParameter(Integer.class) .withReturnType(Integer.class) - .withFunction((Integer x) -> x + 1) + .withFunction((Integer x) -> null != x ? x + 1 : null) .register(conn); String message = @@ -147,9 +147,9 @@ public static void test_register_scalar_function_builder_returns_detached_metada assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); assertNotNull(function.function()); - assertEquals(function.functionKind(), DuckDBFunctionKind.SCALAR); + assertEquals(function.functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); assertTrue(function.isScalar()); - assertEquals(function.propagateNulls(), true); + assertEquals(function.propagateNulls(), false); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_detached(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { @@ -177,7 +177,7 @@ public static void test_register_scalar_function_registry_records_registered_fun List registeredFunctions = DuckDBDriver.registeredFunctions(); assertEquals(registeredFunctions.size(), 1); assertEquals(registeredFunctions.get(0), function); - assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctionKind.SCALAR); + assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); assertTrue(registeredFunctions.get(0).isScalar()); try (ResultSet rs = stmt.executeQuery("SELECT java_registry_recorded(41)")) { @@ -306,7 +306,6 @@ public static void 
test_register_scalar_function_builder_varargs_and_flags() thr .withParameter(intType) .withVarArgs(intType) .withReturnType(intType) - .propagateNulls(true) .withVolatile() .withSpecialHandling() .withVectorizedFunction( @@ -315,7 +314,7 @@ public static void test_register_scalar_function_builder_varargs_and_flags() thr assertEquals(function.varArgType(), intType); assertEquals(function.isVolatile(), true); assertEquals(function.hasSpecialHandling(), true); - assertEquals(function.propagateNulls(), true); + assertEquals(function.propagateNulls(), false); try (ResultSet rs = stmt.executeQuery("SELECT java_sum_varargs_builder(1, 2, 3), java_sum_varargs_builder(5)")) { @@ -335,9 +334,9 @@ public static void test_register_scalar_function_builder_column_type_overloads() .withName("java_add_int_builder_col_type") .withParameter(DuckDBColumnType.INTEGER) .withReturnType(DuckDBColumnType.INTEGER) - .propagateNulls(true) - .withVectorizedFunction( - ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); + }) .register(conn); assertEquals(function.parameterColumnTypes().size(), 1); assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); @@ -358,6 +357,29 @@ public static void test_register_scalar_function_builder_column_type_overloads() } } + public static void test_register_scalar_function_builder_with_parameters_class_helper() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_parameters") + .withParameters(Integer.class, Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer left, Integer right) -> left != null && right != null ? 
left + right : null) + .register(conn); + + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_int_with_parameters(a, b) FROM (VALUES (1, 2), (NULL, 1), (20, 22)) t(a, b)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 3); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + } + } + } + public static void test_register_scalar_function_builder_java_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { @@ -365,7 +387,7 @@ public static void test_register_scalar_function_builder_java_function() throws .withName("java_add_int_function") .withParameter(Integer.class) .withReturnType(Integer.class) - .withFunction((Integer x) -> x + 1) + .withFunction((Integer x) -> null != x ? x + 1 : null) .register(conn); try (ResultSet rs = @@ -389,7 +411,6 @@ public static void test_register_scalar_function_builder_java_function_propagate .withParameter(Integer.class) .withReturnType(Integer.class) .withFunction((Integer x) -> x == null ? 99 : x + 1) - .propagateNulls(false) .register(conn); assertEquals(function.propagateNulls(), false); @@ -415,7 +436,7 @@ public static void test_register_scalar_function_builder_java_bifunction() throw .withParameter(Integer.class) .withParameter(Integer.class) .withReturnType(Integer.class) - .withFunction((Integer x, Integer y) -> x + y) + .withFunction((Integer x, Integer y) -> null != x && null != y ? x + y : null) .register(conn); try ( @@ -445,7 +466,6 @@ public static void test_register_scalar_function_builder_java_bifunction_propaga .withReturnType(Integer.class) .withFunction( (Integer left, Integer right) -> (left == null ? 0 : left) + (right == null ? 
0 : right)) - .propagateNulls(false) .register(conn); assertEquals(function.propagateNulls(), false); @@ -720,30 +740,6 @@ public static void test_register_scalar_function_builder_with_int_function_rejec } } - public static void test_register_scalar_function_builder_with_int_function_rejects_propagate_nulls_false() - throws Exception { - try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { - builder.withName("java_invalid_int_function_null_propagation") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .propagateNulls(false); - String message = assertThrows(() -> { builder.withIntFunction(x -> x + 1); }, SQLException.class); - assertTrue(message.contains("withIntFunction requires propagateNulls(true)")); - } - } - - public static void test_register_scalar_function_builder_with_int_function_rejects_disabling_propagate_nulls() - throws Exception { - try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { - builder.withName("java_invalid_int_function_disable_null_propagation") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withIntFunction(x -> x + 1); - String message = assertThrows(() -> { builder.propagateNulls(false); }, SQLException.class); - assertTrue(message.contains("Primitive scalar callbacks require propagateNulls(true)")); - } - } - public static void test_register_scalar_function_builder_with_int_function_rejects_wrong_types() throws Exception { try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { builder.withName("java_invalid_int_function_type") @@ -769,18 +765,6 @@ public static void test_register_scalar_function_builder_java_bifunction_rejects } } - public static void test_register_scalar_function_builder_with_double_function_rejects_propagate_nulls_false() - throws Exception { - try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { - builder.withName("java_invalid_double_function_null_propagation") - 
.withParameter(Double.class) - .withReturnType(Double.class) - .propagateNulls(false); - String message = assertThrows(() -> { builder.withDoubleFunction(x -> x + 0.5d); }, SQLException.class); - assertTrue(message.contains("withDoubleFunction requires propagateNulls(true)")); - } - } - public static void test_register_scalar_function_builder_with_double_function_rejects_wrong_types() throws Exception { try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { @@ -804,18 +788,6 @@ public static void test_register_scalar_function_builder_with_long_function_reje } } - public static void test_register_scalar_function_builder_with_long_function_rejects_propagate_nulls_false() - throws Exception { - try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { - builder.withName("java_invalid_long_function_null_propagation") - .withParameter(Long.class) - .withReturnType(Long.class) - .propagateNulls(false); - String message = assertThrows(() -> { builder.withLongFunction(x -> x + 1); }, SQLException.class); - assertTrue(message.contains("withLongFunction requires propagateNulls(true)")); - } - } - public static void test_register_scalar_function_builder_with_long_function_rejects_wrong_types() throws Exception { try (DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { builder.withName("java_invalid_long_function_type").withParameter(Integer.class).withReturnType(Long.class); @@ -866,14 +838,14 @@ public static void test_register_scalar_function_builder_java_varargs_function_r } public static void test_register_scalar_function_builder_java_function_supported_class_types() throws Exception { - Function notBoolean = value -> !value; - Function addTinyInt = value -> (byte) (value + 1); - Function addBigInt = value -> value + 3; - Function addDouble = value -> value + 0.5; - Function suffixString = value -> value + "_ok"; - Function addDate = value -> value.plusDays(2); - Function addTimestamp = value -> 
value.plusMinutes(30); - Function addTimestampTz = value -> value.plusMinutes(5); + Function notBoolean = value -> null != value ? !value : null; + Function addTinyInt = value -> null != value ? (byte) (value + 1) : null; + Function addBigInt = value -> null != value ? value + 3 : null; + Function addDouble = value -> null != value ? value + 0.5 : null; + Function suffixString = value -> null != value ? value + "_ok" : null; + Function addDate = value -> null != value ? value.plusDays(2) : null; + Function addTimestamp = value -> null != value ? value.plusMinutes(30) : null; + Function addTimestampTz = value -> null != value ? value.plusMinutes(5) : null; assertUnaryJavaFunction("java_not_bool_function", Boolean.class, Boolean.class, notBoolean, "SELECT java_not_bool_function(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", rs -> { assertTrue(rs.next()); @@ -977,10 +949,10 @@ public static void test_register_scalar_function_builder_java_function_supported } public static void test_register_scalar_function_builder_java_function_supported_unsigned_types() throws Exception { - Function addUTinyInt = value -> (short) (value + 1); - Function addUSmallInt = value -> value + 2; - Function addUInteger = value -> value + 3; - Function addUBigInt = value -> value.add(BigInteger.ONE); + Function addUTinyInt = value -> null != value ? (short) (value + 1) : null; + Function addUSmallInt = value -> null != value ? value + 2 : null; + Function addUInteger = value -> null != value ? value + 3 : null; + Function addUBigInt = value -> null != value ? 
value.add(BigInteger.ONE) : null; assertUnaryJavaFunction( "java_add_utinyint_function", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, addUTinyInt, "SELECT java_add_utinyint_function(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { @@ -1034,9 +1006,29 @@ public static void test_register_scalar_function_builder_java_function_supported }); } + public static void test_register_scalar_function_builder_java_function_hugeint_class_mapping() throws Exception { + Function addHugeInt = + value -> null != value ? value.add(BigInteger.ONE) : null; + assertUnaryJavaFunction( + "java_add_hugeint_function", BigInteger.class, BigInteger.class, addHugeInt, + "SELECT java_add_hugeint_function(v) FROM (VALUES (CAST('41' AS HUGEINT)), (NULL), " + + "(CAST('170141183460469231731687303715884105726' AS HUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("170141183460469231731687303715884105727")); + assertFalse(rs.next()); + }); + } + public static void test_register_scalar_function_builder_java_function_decimal() throws Exception { try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(10, 2)) { - Function addDecimal = value -> value.add(new BigDecimal("1.25")); + Function addDecimal = + value -> null != value ? 
value.add(new BigDecimal("1.25")) : null; assertUnaryJavaFunction("java_add_decimal_function", decimalType, decimalType, addDecimal, "SELECT java_add_decimal_function(v) FROM (VALUES (CAST(40.75 AS DECIMAL(10,2))), " + "(NULL), (CAST(-1.25 AS DECIMAL(10,2)))) t(v)", @@ -1054,7 +1046,7 @@ public static void test_register_scalar_function_builder_java_function_decimal() public static void test_register_scalar_function_builder_java_function_sql_date_class_mapping() throws Exception { Function addSqlDate = - value -> java.sql.Date.valueOf(value.toLocalDate().plusDays(1)); + value -> null != value ? java.sql.Date.valueOf(value.toLocalDate().plusDays(1)) : null; assertUnaryJavaFunction( "java_add_sql_date_function", java.sql.Date.class, java.sql.Date.class, addSqlDate, "SELECT java_add_sql_date_function(v) FROM (VALUES (DATE '2024-07-21'), (NULL), (DATE '2024-07-30')) t(v)", @@ -1072,7 +1064,7 @@ public static void test_register_scalar_function_builder_java_function_sql_date_ public static void test_register_scalar_function_builder_java_function_sql_timestamp_class_mapping() throws Exception { Function addSqlTimestamp = - value -> java.sql.Timestamp.valueOf(value.toLocalDateTime().plusSeconds(1)); + value -> null != value ? java.sql.Timestamp.valueOf(value.toLocalDateTime().plusSeconds(1)) : null; assertUnaryJavaFunction("java_add_sql_timestamp_function", java.sql.Timestamp.class, java.sql.Timestamp.class, addSqlTimestamp, "SELECT java_add_sql_timestamp_function(v) FROM (VALUES " @@ -1089,7 +1081,8 @@ public static void test_register_scalar_function_builder_java_function_sql_times public static void test_register_scalar_function_builder_java_function_java_util_date_class_mapping() throws Exception { - Function addUtilDate = value -> new java.util.Date(value.getTime() + 1000L); + Function addUtilDate = + value -> null != value ? 
new java.util.Date(value.getTime() + 1000L) : null; assertUnaryJavaFunction("java_add_java_util_date_function", java.util.Date.class, java.util.Date.class, addUtilDate, "SELECT java_add_java_util_date_function(v) = date_trunc('millisecond', v) + " @@ -1105,8 +1098,10 @@ public static void test_register_scalar_function_builder_java_function_java_util } public static void test_register_scalar_function_builder_java_bifunction_supported_types() throws Exception { - BiFunction concatUnderscore = (left, right) -> left + "_" + right; - BiFunction sumDouble = Double::sum; + BiFunction concatUnderscore = + (left, right) -> left != null && right != null ? left + "_" + right : null; + BiFunction sumDouble = + (left, right) -> left != null && right != null ? Double.sum(left, right) : null; assertBinaryJavaFunction( "java_concat_varchar_bifunction", String.class, String.class, String.class, concatUnderscore, "SELECT java_concat_varchar_bifunction(a, b) FROM (VALUES ('duck', 'db'), (NULL, 'x'), ('a', NULL)) t(a, b)", @@ -1142,8 +1137,8 @@ public static void test_register_scalar_function_typed_logical_type() throws Exc .withName("java_add_int_typed") .withParameter(intType) .withReturnType(intType) - .propagateNulls(true) - .withVectorizedFunction(ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction( + ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) .register(conn); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_typed(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { @@ -1167,7 +1162,6 @@ public static void test_register_scalar_function_parallel() throws Exception { .withName("java_add_one_bigint") .withParameter(bigintType) .withReturnType(bigintType) - .propagateNulls(true) .withVectorizedFunction(ctx -> { DuckDBWritableVector out = ctx.output(); DuckDBReadableVector in = ctx.input(0); @@ -1195,8 +1189,8 @@ public static void 
test_register_scalar_function_context_row_stream_int() throws .withName("java_add_int_row_stream") .withParameter(intType) .withReturnType(intType) - .propagateNulls(true) - .withVectorizedFunction(ctx -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction( + ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) .register(conn); try (ResultSet rs = @@ -1220,9 +1214,9 @@ public static void test_register_scalar_function_context_row_stream_double() thr .withName("java_add_double_row_stream") .withParameter(doubleType) .withReturnType(doubleType) - .propagateNulls(true) - .withVectorizedFunction( - ctx -> { ctx.stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }) + .withVectorizedFunction(ctx -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); + }) .register(conn); try ( @@ -1239,9 +1233,85 @@ public static void test_register_scalar_function_context_row_stream_double() thr } } + public static void test_register_scalar_function_primitive_nulls_handling() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_primitive_nulls_handling") + .withParameter(DuckDBColumnType.BOOLEAN) + .withParameter(DuckDBColumnType.TINYINT) + .withParameter(DuckDBColumnType.UTINYINT) + .withParameter(DuckDBColumnType.SMALLINT) + .withParameter(DuckDBColumnType.USMALLINT) + .withParameter(DuckDBColumnType.INTEGER) + .withParameter(DuckDBColumnType.UINTEGER) + .withParameter(DuckDBColumnType.BIGINT) + .withParameter(DuckDBColumnType.FLOAT) + .withParameter(DuckDBColumnType.DOUBLE) + .withReturnType(DuckDBColumnType.VARCHAR) + .withSpecialHandling() + .withVectorizedFunction(ctx -> { + assertFalse(ctx.nullsPropagated()); + ctx.stream().forEachOrdered(row -> { + try { + DuckDBReadableVector booleanVector = 
ctx.input(0); + DuckDBReadableVector intVector = ctx.input(5); + assertThrows(() -> { booleanVector.getBoolean(row.index()); }, DuckDBFunctionException.class); + assertTrue(booleanVector.getBoolean(row.index(), true)); + assertThrows(() -> { intVector.getInt(row.index()); }, DuckDBFunctionException.class); + assertEquals(intVector.getInt(row.index(), 42), 42); + + assertThrows(() -> { row.getBoolean(0); }, DuckDBFunctionException.class); + try { + row.getBoolean(0); + fail("Expected row.getBoolean(0) to fail on NULL"); + } catch (DuckDBFunctionException exception) { + assertTrue(exception.getMessage().contains("Failed to read BOOLEAN")); + assertNotNull(exception.getCause()); + assertTrue(exception.getCause() instanceof DuckDBFunctionException); + assertTrue( + exception.getCause().getMessage().contains("Primitive value for BOOLEAN")); + } + assertTrue(row.getBoolean(0, true)); + assertThrows(() -> { row.getByte(1); }, DuckDBFunctionException.class); + assertEquals(row.getByte(1, (byte) 42), (byte) 42); + assertThrows(() -> { row.getUint8(2); }, DuckDBFunctionException.class); + assertEquals(row.getUint8(2, (short) 42), (short) 42); + assertThrows(() -> { row.getShort(3); }, DuckDBFunctionException.class); + assertEquals(row.getShort(3, (short) 42), (short) 42); + assertThrows(() -> { row.getUint16(4); }, DuckDBFunctionException.class); + assertEquals(row.getUint16(4, 42), 42); + assertThrows(() -> { row.getInt(5); }, DuckDBFunctionException.class); + assertEquals(row.getInt(5, 42), 42); + assertThrows(() -> { row.getUint32(6); }, DuckDBFunctionException.class); + assertEquals(row.getUint32(6, (long) 42), (long) 42); + assertThrows(() -> { row.getLong(7); }, DuckDBFunctionException.class); + assertEquals(row.getLong(7, (long) 42), (long) 42); + assertThrows(() -> { row.getFloat(8); }, DuckDBFunctionException.class); + assertEquals(row.getFloat(8, (float) 42.1), (float) 42.1); + assertThrows(() -> { row.getDouble(9); }, DuckDBFunctionException.class); + 
assertEquals(row.getDouble(9, 42.1), 42.1); + } catch (Exception e) { + throw new RuntimeException(e); + } + row.setString("ok"); + }); + }) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_primitive_nulls_handling(NULL::BOOLEAN, NULL::TINYINT, NULL::UTINYINT, NULL::SMALLINT," + + " NULL::USMALLINT, NULL::INTEGER, NULL::UINTEGER, NULL::BIGINT, NULL::FLOAT, NULL::DOUBLE)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "ok"); + assertFalse(rs.next()); + } + } + } + public static void test_register_scalar_function_context_row_stream_propagate_nulls_false() throws Exception { assertUnaryScalarFunction( - "java_suffix_varchar_row_stream_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, false, + "java_suffix_varchar_row_stream_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, ctx -> { ctx.stream().forEachOrdered(row -> { @@ -1316,7 +1386,6 @@ public static void test_register_scalar_function_exception_propagation() throws .withName("java_throws_exception") .withParameter(intType) .withReturnType(intType) - .propagateNulls(true) .withVectorizedFunction(ctx -> { throw new IllegalStateException("boom"); }) .register(conn); String message = @@ -1328,42 +1397,46 @@ public static void test_register_scalar_function_exception_propagation() throws } public static void test_register_scalar_function_boolean() throws Exception { - assertUnaryScalarFunction("java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setBoolean(!row.getBoolean(0))); }, - "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Boolean.class), false); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Boolean.class), true); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_not_bool", DuckDBColumnType.BOOLEAN, 
DuckDBColumnType.BOOLEAN, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setBoolean(!row.getBoolean(0))); }, + "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_tinyint() throws Exception { - assertUnaryScalarFunction("java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setByte((byte) (row.getByte(0) + 1))); }, - "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Byte.class), (byte) 42); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Byte.class), (byte) -1); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setByte((byte) (row.getByte(0) + 1))); }, + "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_smallint() throws Exception { assertUnaryScalarFunction( "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, ctx - -> { ctx.stream().forEachOrdered(row -> row.setShort((short) (row.getShort(0) + 2))); }, + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setShort((short) 
(row.getShort(0) + 2))); + }, "SELECT java_add_smallint(v) FROM (VALUES (40::SMALLINT), (NULL), (-4::SMALLINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -1377,26 +1450,27 @@ public static void test_register_scalar_function_smallint() throws Exception { } public static void test_register_scalar_function_integer() throws Exception { - assertUnaryScalarFunction("java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }, - "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 2); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 42); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }, + "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, ctx -> { - ctx.stream().forEachOrdered(row -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { row.setNull(); row.setInt(row.getInt(0) + 1); }); @@ -1413,26 +1487,27 @@ public static void test_register_scalar_function_integer_revalidates_after_null( } public static void test_register_scalar_function_bigint() throws Exception { - assertUnaryScalarFunction("java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, - ctx - -> { 
ctx.stream().forEachOrdered(row -> row.setLong(row.getLong(0) + 3)); }, - "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Long.class), 42L); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Long.class), -2L); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setLong(row.getLong(0) + 3)); }, + "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_utinyint() throws Exception { assertUnaryScalarFunction( "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, ctx - -> { ctx.stream().forEachOrdered(row -> row.setUint8(row.getUint8(0) + 1)); }, + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint8(row.getUint8(0) + 1)); }, "SELECT java_add_utinyint(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -1449,7 +1524,7 @@ public static void test_register_scalar_function_usmallint() throws Exception { assertUnaryScalarFunction( "java_add_usmallint", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, ctx - -> { ctx.stream().forEachOrdered(row -> row.setUint16(row.getUint16(0) + 2)); }, + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint16(row.getUint16(0) + 2)); }, "SELECT java_add_usmallint(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -1466,7 +1541,7 @@ public static void 
test_register_scalar_function_uinteger() throws Exception { assertUnaryScalarFunction( "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, ctx - -> { ctx.stream().forEachOrdered(row -> row.setUint32(row.getUint32(0) + 3)); }, + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint32(row.getUint32(0) + 3)); }, "SELECT java_add_uinteger(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", rs -> { assertTrue(rs.next()); @@ -1485,7 +1560,7 @@ public static void test_register_scalar_function_ubigint() throws Exception { ctx -> { BigInteger increment = BigInteger.ONE; - ctx.stream().forEachOrdered(row -> row.setUint64(row.getUint64(0).add(increment))); + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint64(row.getUint64(0).add(increment))); }, "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " + "(18446744073709551614::UBIGINT)) t(v)", @@ -1500,60 +1575,114 @@ public static void test_register_scalar_function_ubigint() throws Exception { }); } + public static void test_register_scalar_function_uhugeint() throws Exception { + assertUnaryScalarFunction( + "java_add_uhugeint", DuckDBColumnType.UHUGEINT, DuckDBColumnType.UHUGEINT, + ctx + -> { + BigInteger increment = BigInteger.ONE; + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setUHugeInt(row.getUHugeInt(0).add(increment))); + }, + "SELECT java_add_uhugeint(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("340282366920938463463374607431768211455")); + assertFalse(rs.next()); + }); + } + + public static void test_register_scalar_function_builder_java_function_uhugeint() throws Exception { + try 
(DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName("java_add_uhugeint_function") + .withParameter(DuckDBColumnType.UHUGEINT) + .withReturnType(DuckDBColumnType.UHUGEINT) + .withFunction( + (BigInteger value) + -> null != value ? value.add(BigInteger.ONE) : null) + .register(conn); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT java_add_uhugeint_function(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("340282366920938463463374607431768211455")); + assertFalse(rs.next()); + } + } + } + public static void test_register_scalar_function_float() throws Exception { - assertUnaryScalarFunction("java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setFloat(row.getFloat(0) + 1.25f)); }, - "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Float.class), 42.0f); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Float.class), -1.25f); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setFloat(row.getFloat(0) + 1.25f)); }, + "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), 42.0f); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); 
+ assertEquals(rs.getObject(1, Float.class), -1.25f); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_double() throws Exception { - assertUnaryScalarFunction("java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }, - "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Double.class), 42.0d); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Double.class), -1.5d); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }, + "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_decimal() throws Exception { try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { - assertUnaryScalarFunction( - "java_add_decimal", decimalType, decimalType, - ctx - -> { - BigDecimal increment = new BigDecimal("0.0000000001"); - ctx.stream().forEachOrdered(row -> row.setBigDecimal(row.getBigDecimal(0).add(increment))); - }, - "SELECT java_add_decimal(v) FROM (VALUES " - + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " - + "(NULL), " - + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), new BigDecimal("12345678901234567890.1234567891")); - assertTrue(rs.next()); - 
assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, + ctx + -> { + BigDecimal increment = new BigDecimal("0.0000000001"); + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setBigDecimal(row.getBigDecimal(0).add(increment))); + }, + "SELECT java_add_decimal(v) FROM (VALUES " + + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " + + "(NULL), " + + "(CAST('-0.0000000001' AS DECIMAL(38,10)))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), + new BigDecimal("12345678901234567890.1234567891")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigDecimal.class), BigDecimal.ZERO.setScale(10)); + assertFalse(rs.next()); + }); } } @@ -1565,7 +1694,6 @@ public static void test_register_scalar_function_decimal_precision_overflow() th .withName("java_decimal_precision_overflow") .withParameter(decimalType) .withReturnType(decimalType) - .propagateNulls(true) .withVectorizedFunction(ctx -> { DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); @@ -1590,7 +1718,6 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws .withName("java_decimal_scale_overflow") .withParameter(decimalType) .withReturnType(decimalType) - .propagateNulls(true) .withVectorizedFunction(ctx -> { DuckDBWritableVector out = ctx.output(); long rowCount = ctx.rowCount(); @@ -1611,7 +1738,9 @@ public static void test_register_scalar_function_date() throws Exception { assertUnaryScalarFunction( "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, ctx - -> { ctx.stream().forEachOrdered(row -> row.setDate(row.getLocalDate(0).plusDays(2))); }, + -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDate(row.getLocalDate(0).plusDays(2))); + }, 
"SELECT java_add_date(v) FROM (VALUES (DATE '2024-07-20'), (NULL), (DATE '1969-12-31')) t(v)", rs -> { assertTrue(rs.next()); @@ -1629,7 +1758,7 @@ public static void test_register_scalar_function_date_from_java_util_date() thro assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, ctx -> { - ctx.stream().forEachOrdered(row -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { LocalDate value = row.getLocalDate(0).plusDays(1); row.setDate(java.util.Date.from(value.atStartOfDay(UTC).toInstant())); }); @@ -1645,32 +1774,38 @@ public static void test_register_scalar_function_date_from_java_util_date() thro } public static void test_register_scalar_function_timestamp() throws Exception { - assertUnaryScalarFunction( - "java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusMinutes(30))); }, - "SELECT java_add_timestamp(v) FROM (VALUES " - + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " - + "(NULL), " - + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", - rs -> { - assertTrue(rs.next()); - Timestamp ts1 = rs.getTimestamp(1); - assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); - assertEquals(rs.getObject(1, LocalDateTime.class), LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusMinutes(30))); + }, + "SELECT java_add_timestamp(v) FROM (VALUES " + + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " + + "(NULL), " + + "(TIMESTAMP '1969-12-31 23:45:00')) t(v)", + rs -> { + assertTrue(rs.next()); + 
Timestamp ts1 = rs.getTimestamp(1); + assertEquals(ts1, Timestamp.valueOf("2024-07-21 13:04:56.123456")); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 13, 4, 56, 123456000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1970-01-01 00:15:00")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_s() throws Exception { assertUnaryScalarFunction( "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusSeconds(2))); }, + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusSeconds(2))); + }, "SELECT java_add_timestamp_s(v) FROM (VALUES (TIMESTAMP_S '2024-07-21 12:34:56'), (NULL)) t(v)", rs -> { assertTrue(rs.next()); @@ -1696,20 +1831,22 @@ public static void test_register_scalar_function_timestamp_s_pre_epoch() throws } public static void test_register_scalar_function_timestamp_ms() throws Exception { - assertUnaryScalarFunction( - "java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(7_000_000))); }, - "SELECT java_add_timestamp_ms(v) FROM (VALUES " - + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); - assertTrue(rs.next()); - assertNullRow(rs); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(7_000_000))); + }, + "SELECT 
java_add_timestamp_ms(v) FROM (VALUES " + + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 130_000_000)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { @@ -1728,20 +1865,22 @@ public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws } public static void test_register_scalar_function_timestamp_ns() throws Exception { - assertUnaryScalarFunction( - "java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(789))); }, - "SELECT java_add_timestamp_ns(v) FROM (VALUES " - + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); - assertTrue(rs.next()); - assertNullRow(rs); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, + ctx + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(789))); + }, + "SELECT java_add_timestamp_ns(v) FROM (VALUES " + + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(2024, 7, 21, 12, 34, 56, 123457578)); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamptz() throws Exception { @@ -1749,7 +1888,10 @@ public static void test_register_scalar_function_timestamptz() throws Exception "java_add_timestamptz", 
DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, ctx - -> { ctx.stream().forEachOrdered(row -> row.setOffsetDateTime(row.getOffsetDateTime(0).plusMinutes(5))); }, + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setOffsetDateTime(row.getOffsetDateTime(0).plusMinutes(5))); + }, "SELECT java_add_timestamptz(v) FROM (VALUES " + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", rs -> { @@ -1784,7 +1926,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date() ctx -> { long oneSecondMillis = 1000L; - ctx.stream().forEachOrdered( + ctx.propagateNulls(true).stream().forEachOrdered( row -> { row.setTimestamp(new java.util.Date(row.getTimestamp(0).getTime() + oneSecondMillis)); }); }, "SELECT epoch_ms(java_timestamp_from_util_date(v)) FROM (VALUES " @@ -1806,7 +1948,7 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, ctx -> { - ctx.stream().forEachOrdered(row -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { java.util.Date value = Timestamp.valueOf(row.getLocalDateTime(0).plusNanos(789000)); row.setTimestamp(value); }); @@ -1860,7 +2002,10 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr assertUnaryScalarFunction( "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDate(0).plusDays(1))); }, + -> { + ctx.propagateNulls(true).stream().forEachOrdered( + row -> row.setTimestamp(row.getLocalDate(0).plusDays(1))); + }, "SELECT java_timestamp_from_local_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", rs -> { assertTrue(rs.next()); @@ -1873,45 +2018,46 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr } public static void test_register_scalar_function_varchar() 
throws Exception { - assertUnaryScalarFunction("java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setString(row.getString(0) + "_java")); }, - "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " - + "('abcdefghijklmnop')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "duck_java"); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0) + "_java")); }, + "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_java"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { - assertUnaryScalarFunction("java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - false, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setString(row.getString(0))); }, - "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " - + "('abcdefghijklmnop')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "duck"); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + ctx + -> { 
ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0))); }, + "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, ctx -> { - ctx.stream().forEachOrdered(row -> { + ctx.propagateNulls(true).stream().forEachOrdered(row -> { row.setNull(); row.setString(row.getString(0) + "_ok"); }); @@ -1930,37 +2076,21 @@ public static void test_register_scalar_function_varchar_revalidates_after_null( private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, DuckDBColumnType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { - assertUnaryScalarFunction(functionName, parameterType, returnType, true, function, query, verifier); - } - - private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, - DuckDBColumnType returnType, boolean propagateNulls, - DuckDBScalarFunction function, String query, - ResultSetVerifier verifier) throws Exception { try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { - assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, propagateNulls, function, - query, verifier); + assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); } } private static void assertUnaryScalarFunction(String functionName, 
DuckDBLogicalType parameterType, DuckDBLogicalType returnType, DuckDBScalarFunction function, String query, ResultSetVerifier verifier) throws Exception { - assertUnaryScalarFunction(functionName, parameterType, returnType, true, function, query, verifier); - } - - private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, - DuckDBLogicalType returnType, boolean propagateNulls, - DuckDBScalarFunction function, String query, - ResultSetVerifier verifier) throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { DuckDBFunctions.scalarFunction() .withName(functionName) .withParameter(parameterType) .withReturnType(returnType) - .propagateNulls(propagateNulls) .withVectorizedFunction(function) .register(conn); try (ResultSet rs = stmt.executeQuery(query)) { From caa02cd4d7fabf25465f482d2c3e1abc351dabe6 Mon Sep 17 00:00:00 2001 From: Luis Fernando Kauer Date: Tue, 7 Apr 2026 23:49:38 -0300 Subject: [PATCH 9/9] Apply formatting-only cleanup after format-check Run project formatter and keep only style/layout changes required by format-check, with no functional modifications. 
--- .../java/org/duckdb/DuckDBLogicalType.java | 2 +- .../org/duckdb/DuckDBReadableVectorImpl.java | 12 ++- .../org/duckdb/DuckDBRegisteredFunction.java | 9 +- .../duckdb/DuckDBScalarFunctionAdapter.java | 18 ++-- src/main/java/org/duckdb/DuckDBScalarRow.java | 1 - .../org/duckdb/DuckDBWritableVectorImpl.java | 8 +- src/test/java/org/duckdb/TestBindings.java | 4 +- .../java/org/duckdb/TestScalarFunctions.java | 86 +++++++++---------- 8 files changed, 63 insertions(+), 77 deletions(-) diff --git a/src/main/java/org/duckdb/DuckDBLogicalType.java b/src/main/java/org/duckdb/DuckDBLogicalType.java index 26a0f7436..2d4f53b10 100644 --- a/src/main/java/org/duckdb/DuckDBLogicalType.java +++ b/src/main/java/org/duckdb/DuckDBLogicalType.java @@ -1,7 +1,7 @@ package org.duckdb; -import static org.duckdb.DuckDBBindings.CAPIType.*; import static org.duckdb.DuckDBBindings.*; +import static org.duckdb.DuckDBBindings.CAPIType.*; import java.nio.ByteBuffer; import java.sql.SQLException; diff --git a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java index fe4c322e3..58b0a3f03 100644 --- a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java +++ b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java @@ -38,7 +38,8 @@ final class DuckDBReadableVectorImpl extends DuckDBReadableVector { } catch (java.sql.SQLException exception) { throw new DuckDBFunctionException("Failed to resolve vector type info", exception); } - this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); this.validity = validityBuffer == null ? 
null : validityBuffer.order(NATIVE_ORDER); } @@ -320,14 +321,11 @@ public BigDecimal getBigDecimal(long row) { } switch (typeInfo.storageType) { case DUCKDB_TYPE_SMALLINT: - return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), - typeInfo.decimalMeta.scale); + return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_INTEGER: - return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), - typeInfo.decimalMeta.scale); + return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_BIGINT: - return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), - typeInfo.decimalMeta.scale); + return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), typeInfo.decimalMeta.scale); case DUCKDB_TYPE_HUGEINT: { int offset = checkedByteOffset(row, typeInfo.widthBytes); long lower = data.getLong(offset); diff --git a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java index 517b7908c..f11f6c968 100644 --- a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java +++ b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java @@ -89,9 +89,10 @@ static DuckDBRegisteredFunction of(String name, List paramete DuckDBColumnType returnColumnType, DuckDBScalarFunction function, DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, boolean propagateNullsFlag) { - return new DuckDBRegisteredFunction( - name, DuckDBFunctions.DuckDBFunctionKind.SCALAR, Collections.unmodifiableList(new ArrayList<>(parameterTypes)), - Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), returnType, returnColumnType, function, - varArgType, volatileFlag, specialHandlingFlag, propagateNullsFlag); + return new DuckDBRegisteredFunction(name, DuckDBFunctions.DuckDBFunctionKind.SCALAR, + Collections.unmodifiableList(new 
ArrayList<>(parameterTypes)), + Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), + returnType, returnColumnType, function, varArgType, volatileFlag, + specialHandlingFlag, propagateNullsFlag); } } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java index f95f9599f..f67d20a72 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -190,8 +190,7 @@ static DuckDBScalarFunction intUnary(IntUnaryOperator function) { out.setInt(row, function.applyAsInt(in.getInt(row))); } } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); } } }; @@ -214,8 +213,7 @@ static DuckDBScalarFunction intBinary(IntBinaryOperator function) { out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); } } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); } } }; @@ -237,8 +235,7 @@ static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { out.setDouble(row, function.applyAsDouble(in.getDouble(row))); } } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); } } }; @@ -261,8 +258,7 @@ static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { out.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); } } catch (DuckDBFunctionException exception) { - throw new 
DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); } } }; @@ -284,8 +280,7 @@ static DuckDBScalarFunction longUnary(LongUnaryOperator function) { out.setLong(row, function.applyAsLong(in.getLong(row))); } } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); } } }; @@ -308,8 +303,7 @@ static DuckDBScalarFunction longBinary(LongBinaryOperator function) { out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); } } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, - exception); + throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); } } }; diff --git a/src/main/java/org/duckdb/DuckDBScalarRow.java b/src/main/java/org/duckdb/DuckDBScalarRow.java index 5033faaba..7b51cf29e 100644 --- a/src/main/java/org/duckdb/DuckDBScalarRow.java +++ b/src/main/java/org/duckdb/DuckDBScalarRow.java @@ -468,5 +468,4 @@ private DuckDBFunctionException readFailure(String type, int columnIndex, DuckDB private DuckDBFunctionException writeFailure(String type, DuckDBFunctionException exception) { return new DuckDBFunctionException("Failed to write " + type + " to output row " + rowIndex, exception); } - } diff --git a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java index 1b7e8cb13..33887abb7 100644 --- a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java +++ b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java @@ -37,7 +37,8 @@ final class DuckDBWritableVectorImpl extends DuckDBWritableVector { } catch (java.sql.SQLException exception) { throw 
new DuckDBFunctionException("Failed to resolve vector type info", exception); } - this.data = duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); } @@ -527,8 +528,9 @@ public void setOffsetDateTime(long row, OffsetDateTime value) { setNull(row); return; } - data.putLong(checkedByteOffset(row, Long.BYTES), - DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); + data.putLong( + checkedByteOffset(row, Long.BYTES), + DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); markValid(row); } diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index a0f4f1278..8bdf0ad3f 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -34,9 +34,7 @@ public static void test_bindings_vector_row_index_stream() throws Exception { DuckDBReadableVector readable = new DuckDBReadableVectorImpl(inputVec, 3); DuckDBWritableVector output = new DuckDBWritableVectorImpl(outputVec, 3); - readable.rowIndexStream().forEachOrdered(row -> { - output.setInt(row, readable.getInt(row) + 1); - }); + readable.rowIndexStream().forEachOrdered(row -> { output.setInt(row, readable.getInt(row) + 1); }); DuckDBReadableVector result = new DuckDBReadableVectorImpl(outputVec, 3); assertEquals(result.getInt(0), 2); diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index a9c3afbc8..8c4fc856e 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -1007,22 
+1007,20 @@ public static void test_register_scalar_function_builder_java_function_supported } public static void test_register_scalar_function_builder_java_function_hugeint_class_mapping() throws Exception { - Function addHugeInt = - value -> null != value ? value.add(BigInteger.ONE) : null; - assertUnaryJavaFunction( - "java_add_hugeint_function", BigInteger.class, BigInteger.class, addHugeInt, - "SELECT java_add_hugeint_function(v) FROM (VALUES (CAST('41' AS HUGEINT)), (NULL), " - + "(CAST('170141183460469231731687303715884105726' AS HUGEINT))) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), - new BigInteger("170141183460469231731687303715884105727")); - assertFalse(rs.next()); - }); + Function addHugeInt = value -> null != value ? value.add(BigInteger.ONE) : null; + assertUnaryJavaFunction("java_add_hugeint_function", BigInteger.class, BigInteger.class, addHugeInt, + "SELECT java_add_hugeint_function(v) FROM (VALUES (CAST('41' AS HUGEINT)), (NULL), " + + "(CAST('170141183460469231731687303715884105726' AS HUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("170141183460469231731687303715884105727")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_builder_java_function_decimal() throws Exception { @@ -1255,7 +1253,8 @@ public static void test_register_scalar_function_primitive_nulls_handling() thro try { DuckDBReadableVector booleanVector = ctx.input(0); DuckDBReadableVector intVector = ctx.input(5); - assertThrows(() -> { booleanVector.getBoolean(row.index()); }, DuckDBFunctionException.class); + assertThrows( + () -> { 
booleanVector.getBoolean(row.index()); }, DuckDBFunctionException.class); assertTrue(booleanVector.getBoolean(row.index(), true)); assertThrows(() -> { intVector.getInt(row.index()); }, DuckDBFunctionException.class); assertEquals(intVector.getInt(row.index(), 42), 42); @@ -1268,8 +1267,7 @@ public static void test_register_scalar_function_primitive_nulls_handling() thro assertTrue(exception.getMessage().contains("Failed to read BOOLEAN")); assertNotNull(exception.getCause()); assertTrue(exception.getCause() instanceof DuckDBFunctionException); - assertTrue( - exception.getCause().getMessage().contains("Primitive value for BOOLEAN")); + assertTrue(exception.getCause().getMessage().contains("Primitive value for BOOLEAN")); } assertTrue(row.getBoolean(0, true)); assertThrows(() -> { row.getByte(1); }, DuckDBFunctionException.class); @@ -1576,26 +1574,25 @@ public static void test_register_scalar_function_ubigint() throws Exception { } public static void test_register_scalar_function_uhugeint() throws Exception { - assertUnaryScalarFunction( - "java_add_uhugeint", DuckDBColumnType.UHUGEINT, DuckDBColumnType.UHUGEINT, - ctx - -> { - BigInteger increment = BigInteger.ONE; - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setUHugeInt(row.getUHugeInt(0).add(increment))); - }, - "SELECT java_add_uhugeint(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " - + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), - new BigInteger("340282366920938463463374607431768211455")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_uhugeint", DuckDBColumnType.UHUGEINT, DuckDBColumnType.UHUGEINT, + ctx + -> { + BigInteger increment = BigInteger.ONE; + ctx.propagateNulls(true).stream().forEachOrdered( + row 
-> row.setUHugeInt(row.getUHugeInt(0).add(increment))); + }, + "SELECT java_add_uhugeint(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("340282366920938463463374607431768211455")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_builder_java_function_uhugeint() throws Exception { @@ -1605,15 +1602,12 @@ public static void test_register_scalar_function_builder_java_function_uhugeint( .withName("java_add_uhugeint_function") .withParameter(DuckDBColumnType.UHUGEINT) .withReturnType(DuckDBColumnType.UHUGEINT) - .withFunction( - (BigInteger value) - -> null != value ? value.add(BigInteger.ONE) : null) + .withFunction((BigInteger value) -> null != value ? value.add(BigInteger.ONE) : null) .register(conn); - try ( - ResultSet rs = stmt.executeQuery( - "SELECT java_add_uhugeint_function(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " - + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)")) { + try (ResultSet rs = stmt.executeQuery( + "SELECT java_add_uhugeint_function(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)")) { assertTrue(rs.next()); assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); assertTrue(rs.next());