From b8908f07908dc4155760522393d003eec53166b3 Mon Sep 17 00:00:00 2001 From: Alex Kasko Date: Fri, 10 Apr 2026 23:32:35 +0100 Subject: [PATCH] Scalar functions: simplify JNI handling This PR is a follow-up to #630 and #637. It removes JNI utilities specific to scalar functions in favour of more generic `GlobalRefHolder` utility. Testing: no functional changes, no new tests --- CMakeLists.txt | 2 +- CMakeLists.txt.in | 2 +- duckdb_java.def | 3 +- duckdb_java.exp | 3 +- duckdb_java.map | 3 +- src/jni/bindings_scalar_function.cpp | 129 ++++++++- src/jni/holders.cpp | 120 ++++++++ src/jni/holders.hpp | 75 +++-- src/jni/refs.cpp | 8 + src/jni/refs.hpp | 3 + src/jni/scalar_functions.cpp | 271 ------------------ src/jni/scalar_functions.hpp | 5 - src/main/java/org/duckdb/DuckDBBindings.java | 4 +- .../duckdb/DuckDBScalarFunctionBuilder.java | 4 +- .../duckdb/DuckDBScalarFunctionWrapper.java | 3 +- 15 files changed, 305 insertions(+), 330 deletions(-) create mode 100644 src/jni/holders.cpp delete mode 100644 src/jni/scalar_functions.cpp delete mode 100644 src/jni/scalar_functions.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7bd2316a9..6e51ce504 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -597,8 +597,8 @@ add_library(duckdb_java SHARED src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp + src/jni/holders.cpp src/jni/refs.cpp - src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/CMakeLists.txt.in b/CMakeLists.txt.in index 751076c19..928575a56 100644 --- a/CMakeLists.txt.in +++ b/CMakeLists.txt.in @@ -115,8 +115,8 @@ add_library(duckdb_java SHARED src/jni/config.cpp src/jni/duckdb_java.cpp src/jni/functions.cpp + src/jni/holders.cpp src/jni/refs.cpp - src/jni/scalar_functions.cpp src/jni/types.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/duckdb_java.def b/duckdb_java.def index 7f340311b..503344a45 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -60,6 +60,8 @@ Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1extra_1info +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type @@ -111,7 +113,6 @@ Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk -Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function duckdb_adbc_init duckdb_add_aggregate_function_to_set diff --git a/duckdb_java.exp b/duckdb_java.exp index b0611b6f6..5d2f47f4f 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -57,13 +57,14 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling -_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error _Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string _Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string__Ljava_nio_ByteBuffer_2J _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type _Java_org_duckdb_DuckDBBindings_duckdb_1create_1decimal_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1extra_1info _Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1width _Java_org_duckdb_DuckDBBindings_duckdb_1decimal_1scale diff --git a/duckdb_java.map b/duckdb_java.map index 0fc9c6fc7..435064f74 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -59,6 +59,8 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1extra_1info; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function; Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error; Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1string; @@ -110,7 +112,6 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type; Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk; - Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function; duckdb_adbc_init; duckdb_add_aggregate_function_to_set; diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp index 5fb9c0b7a..369b6fb7e 100644 --- a/src/jni/bindings_scalar_function.cpp +++ b/src/jni/bindings_scalar_function.cpp @@ -2,7 +2,6 @@ #include "functions.hpp" #include "holders.hpp" #include "refs.hpp" -#include "scalar_functions.hpp" #include "util.hpp" static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { @@ -36,10 +35,20 @@ static duckdb_function_info function_info_buf_to_function_info(JNIEnv *env, jobj return function_info; } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_create_scalar_function + * Signature: ()Ljava/nio/ByteBuffer; + */ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function(JNIEnv *env, jclass) { return make_ptr_buf(env, duckdb_create_scalar_function()); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_destroy_scalar_function + * Signature: (Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function(JNIEnv *env, jclass, jobject scalar_function) { auto function = scalar_function_buf_to_scalar_function(env, scalar_function); @@ -49,6 +58,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1f duckdb_destroy_scalar_function(&function); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_name + * Signature: (Ljava/nio/ByteBuffer;[B)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name(JNIEnv *env, jclass, jobject scalar_function, jbyteArray name) { @@ -67,6 +81,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_name(function, function_name.c_str()); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_add_parameter + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter(JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { @@ -81,6 +100,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_add_parameter(function, type); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_return_type + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type( JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { auto function = scalar_function_buf_to_scalar_function(env, scalar_function); @@ -94,6 +118,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_return_type(function, type); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_varargs + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1varargs(JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { @@ -108,6 +137,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_varargs(function, type); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_volatile + * Signature: (Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile(JNIEnv *env, jclass, jobject scalar_function) { auto function = scalar_function_buf_to_scalar_function(env, scalar_function); @@ -117,6 +151,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_volatile(function); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_special_handling + * Signature: (Ljava/nio/ByteBuffer;)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling( JNIEnv *env, jclass, jobject scalar_function) { auto function = scalar_function_buf_to_scalar_function(env, scalar_function); @@ -126,6 +165,11 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1 duckdb_scalar_function_set_special_handling(function); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_register_scalar_function + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)I + */ JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function(JNIEnv *env, jclass, jobject connection, jobject scalar_function) { @@ -140,16 +184,85 @@ JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1 return static_cast(duckdb_register_scalar_function(conn, function)); } -JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function( - JNIEnv *env, jclass, jobject scalar_function_buf, jobject function_j) { - try { - scalar_function_set_function(env, scalar_function_buf, function_j); - } catch (const std::exception &e) { - duckdb::ErrorData error(e); - ThrowJNI(env, error.Message().c_str()); +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_extra_info + * Signature: (Ljava/nio/ByteBuffer;Ljava/lang/Object;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1extra_1info( + JNIEnv *env, jclass, jobject scalar_function, jobject callback) { + + if (callback == nullptr) { + env->ThrowNew(J_SQLException, "Specified callback must be not null"); + return; + } + + auto sf = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + auto callback_holder = std::unique_ptr(new GlobalRefHolder(env, callback)); + if (callback_holder->vm == nullptr) { + env->ThrowNew(J_SQLException, "Unable to create a global reference to the specified scalar function callback"); + return; } + + duckdb_scalar_function_set_extra_info(sf, callback_holder.release(), GlobalRefHolder::destroy); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_function + * Signature: (Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1function(JNIEnv *env, jclass, + jobject scalar_function) { + auto sf = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_set_function( + sf, [](duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) { + auto callback_holder = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); + AttachedJNIEnv attached = callback_holder->attach_current_thread(); + if (attached.env == nullptr) { + duckdb_scalar_function_set_error(info, "Unable to attach JNI environment"); + return; + } + jobject info_buf = make_ptr_buf(attached.env, info); + if (info_buf == nullptr) { + duckdb_scalar_function_set_error(info, "Unable to create function info buffer"); + return; + } + LocalRefHolder info_buf_holder(attached.env, info_buf); + jobject input_buf = make_ptr_buf(attached.env, input); + if (input_buf == nullptr) { + duckdb_scalar_function_set_error(info, "Unable to create input buffer"); + return; + } + LocalRefHolder input_buf_holder(attached.env, input_buf); + jobject output_buf = make_ptr_buf(attached.env, output); + if (output_buf == nullptr) { + duckdb_scalar_function_set_error(info, "Unable to create output buffer"); + return; + } + LocalRefHolder output_buf_holder(attached.env, output_buf); + + attached.env->CallVoidMethod(callback_holder->global_ref, J_DuckDBScalarFunctionWrapper_execute, info_buf, + input_buf, output_buf); + if (attached.env->ExceptionCheck()) { + duckdb_scalar_function_set_error(info, "Java callback system error"); + } + }); } +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_error + * Signature: (Ljava/nio/ByteBuffer;[B)V + */ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1error(JNIEnv *env, jclass, jobject function_info_buf, jbyteArray error) { diff --git a/src/jni/holders.cpp b/src/jni/holders.cpp new file mode 100644 index 000000000..a8b3661dc --- /dev/null +++ b/src/jni/holders.cpp @@ -0,0 +1,120 @@ +#include "holders.hpp" + +ConnectionHolder *get_connection_ref(JNIEnv *env, jobject conn_ref_buf) { + if (!conn_ref_buf) { + throw duckdb::ConnectionException("Invalid connection buffer ref"); + } + auto conn_holder = reinterpret_cast(env->GetDirectBufferAddress(conn_ref_buf)); + if (!conn_holder) { + throw duckdb::ConnectionException("Invalid connection buffer"); + } + return conn_holder; +} + +/** + * Throws a SQLException and returns nullptr if a valid Connection can't be retrieved from the buffer. + */ +duckdb::Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { + auto conn_holder = get_connection_ref(env, conn_ref_buf); + auto conn_ref = conn_holder->connection.get(); + if (!conn_ref || !conn_ref->context) { + throw duckdb::ConnectionException("Invalid connection"); + } + + return conn_ref; +} + +duckdb_connection conn_ref_buf_to_conn(JNIEnv *env, jobject conn_ref_buf) { + if (conn_ref_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid connection buffer"); + return nullptr; + } + auto conn_holder = reinterpret_cast(env->GetDirectBufferAddress(conn_ref_buf)); + if (conn_holder == nullptr) { + env->ThrowNew(J_SQLException, "Invalid connection holder"); + return nullptr; + } + auto conn_ref = conn_holder->connection.get(); + if (conn_ref == nullptr || conn_ref->context == nullptr) { + env->ThrowNew(J_SQLException, "Invalid connection"); + return nullptr; + } + + return reinterpret_cast(conn_ref); +} + +AttachedJNIEnv::AttachedJNIEnv() { +} + +AttachedJNIEnv::AttachedJNIEnv(JavaVM *vm_in, JNIEnv *env_in, bool need_to_detach_in) + : vm(vm_in), env(env_in), need_to_detach(need_to_detach_in) { +} + +AttachedJNIEnv::~AttachedJNIEnv() noexcept { + if (vm == nullptr) { + return; + } + if (need_to_detach) { + vm->DetachCurrentThread(); + } +} + +GlobalRefHolder::GlobalRefHolder(JNIEnv *env, jobject local_ref) { + if (env->GetJavaVM(&this->vm) != JNI_OK || this->vm == nullptr) { + this->vm = nullptr; + return; + } + if (local_ref != nullptr) { + this->global_ref = env->NewGlobalRef(local_ref); + if (this->global_ref == nullptr) { + this->vm = nullptr; + } + } +} + +GlobalRefHolder::~GlobalRefHolder() noexcept { + if (global_ref == nullptr) { + return; + } + AttachedJNIEnv attached = attach_current_thread(); + if (attached.env == nullptr) { + return; + } + attached.env->DeleteGlobalRef(global_ref); +} + +AttachedJNIEnv GlobalRefHolder::attach_current_thread() { + if (vm == nullptr) { + return AttachedJNIEnv(); + } + JNIEnv *env = nullptr; + auto env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (env_status != JNI_OK && env_status != JNI_EDETACHED) { + return AttachedJNIEnv(); + } + bool need_to_detach = false; + if (env_status == JNI_EDETACHED) { + auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); + if (attach_status != JNI_OK || env == nullptr) { + return AttachedJNIEnv(); + } + need_to_detach = true; + } + + return AttachedJNIEnv(vm, env, need_to_detach); +} + +void GlobalRefHolder::destroy(void *holder_in) noexcept { + auto holder = reinterpret_cast(holder_in); + delete holder; +} + +LocalRefHolder::LocalRefHolder(JNIEnv *env_in, jobject local_ref_in) : env(env_in), local_ref(local_ref_in) { +} + +LocalRefHolder::~LocalRefHolder() noexcept { + if (env == nullptr || local_ref == nullptr) { + return; + } + env->DeleteLocalRef(local_ref); +} diff --git a/src/jni/holders.hpp b/src/jni/holders.hpp index 74e0744cf..15a37a21b 100644 --- a/src/jni/holders.hpp +++ b/src/jni/holders.hpp @@ -55,45 +55,44 @@ struct ResultHolder { duckdb::unique_ptr chunk; }; -inline ConnectionHolder *get_connection_ref(JNIEnv *env, jobject conn_ref_buf) { - if (!conn_ref_buf) { - throw duckdb::ConnectionException("Invalid connection buffer ref"); - } - auto conn_holder = reinterpret_cast(env->GetDirectBufferAddress(conn_ref_buf)); - if (!conn_holder) { - throw duckdb::ConnectionException("Invalid connection buffer"); - } - return conn_holder; -} +struct AttachedJNIEnv { + JavaVM *vm = nullptr; + JNIEnv *env = nullptr; + bool need_to_detach = false; -/** - * Throws a SQLException and returns nullptr if a valid Connection can't be retrieved from the buffer. - */ -inline duckdb::Connection *get_connection(JNIEnv *env, jobject conn_ref_buf) { - auto conn_holder = get_connection_ref(env, conn_ref_buf); - auto conn_ref = conn_holder->connection.get(); - if (!conn_ref || !conn_ref->context) { - throw duckdb::ConnectionException("Invalid connection"); - } + AttachedJNIEnv(); - return conn_ref; -} + AttachedJNIEnv(JavaVM *vm_in, JNIEnv *env_in, bool need_to_detach_in); -inline duckdb_connection conn_ref_buf_to_conn(JNIEnv *env, jobject conn_ref_buf) { - if (conn_ref_buf == nullptr) { - env->ThrowNew(J_SQLException, "Invalid connection buffer"); - return nullptr; - } - auto conn_holder = reinterpret_cast(env->GetDirectBufferAddress(conn_ref_buf)); - if (conn_holder == nullptr) { - env->ThrowNew(J_SQLException, "Invalid connection holder"); - return nullptr; - } - auto conn_ref = conn_holder->connection.get(); - if (conn_ref == nullptr || conn_ref->context == nullptr) { - env->ThrowNew(J_SQLException, "Invalid connection"); - return nullptr; - } + ~AttachedJNIEnv() noexcept; +}; - return reinterpret_cast(conn_ref); -} +struct GlobalRefHolder { + JavaVM *vm = nullptr; + jobject global_ref = nullptr; + + GlobalRefHolder(JNIEnv *env, jobject local_ref); + + ~GlobalRefHolder() noexcept; + + AttachedJNIEnv attach_current_thread(); + + void detach_current_thread(); + + static void destroy(void *holder_in) noexcept; +}; + +struct LocalRefHolder { + JNIEnv *env = nullptr; + jobject local_ref = nullptr; + + LocalRefHolder(JNIEnv *env_in, jobject local_ref_in); + + ~LocalRefHolder() noexcept; +}; + +ConnectionHolder *get_connection_ref(JNIEnv *env, jobject conn_ref_buf); + +duckdb::Connection *get_connection(JNIEnv *env, jobject conn_ref_buf); + +duckdb_connection conn_ref_buf_to_conn(JNIEnv *env, jobject conn_ref_buf); diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index 1dc6cbafe..e1a881d2c 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -119,6 +119,9 @@ jobject J_ProfilerPrintFormat_GRAPHVIZ; jclass J_QueryProgress; jmethodID J_QueryProgress_init; +jclass J_DuckDBScalarFunctionWrapper; +jmethodID J_DuckDBScalarFunctionWrapper_execute; + static std::vector global_refs; template @@ -308,6 +311,11 @@ void create_refs(JNIEnv *env) { J_QueryProgress = make_class_ref(env, "org/duckdb/QueryProgress"); J_QueryProgress_init = get_method_id(env, J_QueryProgress, "", "(DJJ)V"); + + J_DuckDBScalarFunctionWrapper = make_class_ref(env, "org/duckdb/DuckDBScalarFunctionWrapper"); + J_DuckDBScalarFunctionWrapper_execute = + get_method_id(env, J_DuckDBScalarFunctionWrapper, "execute", + "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V"); } void delete_global_refs(JNIEnv *env) noexcept { diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index cda859d33..fe830584d 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -116,6 +116,9 @@ extern jobject J_ProfilerPrintFormat_GRAPHVIZ; extern jclass J_QueryProgress; extern jmethodID J_QueryProgress_init; +extern jclass J_DuckDBScalarFunctionWrapper; +extern jmethodID J_DuckDBScalarFunctionWrapper_execute; + void create_refs(JNIEnv *env); void delete_global_refs(JNIEnv *env) noexcept; diff --git a/src/jni/scalar_functions.cpp b/src/jni/scalar_functions.cpp deleted file mode 100644 index 054561ac1..000000000 --- a/src/jni/scalar_functions.cpp +++ /dev/null @@ -1,271 +0,0 @@ -extern "C" { -#include "duckdb.h" -} - -#include "holders.hpp" -#include "refs.hpp" -#include "scalar_functions.hpp" -#include "util.hpp" - -#include -#include - -class ScalarFunctionException : public std::exception { -public: - explicit ScalarFunctionException(std::string message_p) : message(std::move(message_p)) { - } - - const char *what() const noexcept override { - return message.c_str(); - } - -private: - std::string message; -}; - -struct JNIEnvGuard { - JavaVM *vm; - JNIEnv *env; - bool detach_when_done; - - explicit JNIEnvGuard(JavaVM *vm_p) : vm(vm_p), env(nullptr), detach_when_done(false) { - if (!vm) { - throw ScalarFunctionException("JVM is not available"); - } - auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); - if (get_env_status == JNI_OK) { - return; - } - if (get_env_status != JNI_EDETACHED) { - throw ScalarFunctionException("Failed to get JNI environment"); - } - auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - if (attach_status != JNI_OK || !env) { - throw ScalarFunctionException("Failed to attach current thread to JVM"); - } - detach_when_done = true; - } - - ~JNIEnvGuard() { - if (detach_when_done && vm) { - vm->DetachCurrentThread(); - } - } -}; - -struct JavaScalarFunctionState { - JavaVM *vm; - jobject callback; - jmethodID apply_method; - - JavaScalarFunctionState(JavaVM *vm_p, jobject callback_p, jmethodID apply_method_p) - : vm(vm_p), callback(callback_p), apply_method(apply_method_p) { - } - - ~JavaScalarFunctionState() { - if (!vm || !callback) { - return; - } - try { - JNIEnvGuard env_guard(vm); - env_guard.env->DeleteGlobalRef(callback); - } catch (...) { - // noop in destructor - } - } -}; - -struct JavaScalarFunctionLocalState { - JavaVM *vm; - JNIEnv *env; - bool detach_when_done; -}; - -static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { - if (scalar_function_buf == nullptr) { - env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); - return nullptr; - } - - auto scalar_function = reinterpret_cast(env->GetDirectBufferAddress(scalar_function_buf)); - if (scalar_function == nullptr) { - env->ThrowNew(J_SQLException, "Invalid scalar function"); - return nullptr; - } - return scalar_function; -} - -static std::string consume_java_exception_message(JNIEnv *env) { - auto throwable = env->ExceptionOccurred(); - if (!throwable) { - return "Java exception"; - } - env->ExceptionClear(); - - std::string message = "Java exception"; - auto msg = (jstring)env->CallObjectMethod(throwable, J_Throwable_getMessage); - if (!env->ExceptionCheck() && msg) { - message = jstring_to_string(env, msg); - } - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } - - env->DeleteLocalRef(throwable); - if (msg) { - env->DeleteLocalRef(msg); - } - - return message; -} - -static void get_or_attach_jni_env(JavaVM *vm, JNIEnv *&env, bool &detach_when_done) { - if (!vm) { - throw ScalarFunctionException("JVM is not available"); - } - - detach_when_done = false; - auto get_env_status = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); - if (get_env_status == JNI_OK) { - return; - } - if (get_env_status != JNI_EDETACHED) { - throw ScalarFunctionException("Failed to get JNI environment"); - } - - auto attach_status = vm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - if (attach_status != JNI_OK || !env) { - throw ScalarFunctionException("Failed to attach current thread to JVM"); - } - detach_when_done = true; -} - -static void execute_java_scalar_function(JNIEnv *env, JavaScalarFunctionState &state, duckdb_function_info info, - duckdb_data_chunk input, duckdb_vector output) { - jobject function_info_buf = make_ptr_buf(env, info); - jobject input_chunk_buf = make_ptr_buf(env, input); - jobject output_vector_buf = make_ptr_buf(env, output); - env->CallVoidMethod(state.callback, state.apply_method, function_info_buf, input_chunk_buf, output_vector_buf); - if (function_info_buf) { - env->DeleteLocalRef(function_info_buf); - } - if (input_chunk_buf) { - env->DeleteLocalRef(input_chunk_buf); - } - if (output_vector_buf) { - env->DeleteLocalRef(output_vector_buf); - } - - if (env->ExceptionCheck()) { - throw ScalarFunctionException("Java scalar function wrapper threw exception: " + - consume_java_exception_message(env)); - } -} - -static void destroy_java_scalar_function_state(void *extra_info); -static void init_java_scalar_function_capi(duckdb_init_info info); -static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output); - -static jmethodID get_scalar_callback_method(JNIEnv *env, jobject function_j, const char *signature, - const char *method_name, const char *error_message) { - auto callback_class = env->GetObjectClass(function_j); - auto apply_method = env->GetMethodID(callback_class, method_name, signature); - env->DeleteLocalRef(callback_class); - if (!apply_method || env->ExceptionCheck()) { - consume_java_exception_message(env); - throw ScalarFunctionException(error_message); - } - return apply_method; -} - -void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j) { - auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf); - if (env->ExceptionCheck()) { - return; - } - if (!function_j) { - throw ScalarFunctionException("Invalid scalar function callback"); - } - - JavaVM *vm = nullptr; - if (env->GetJavaVM(&vm) != JNI_OK || !vm) { - throw ScalarFunctionException("Failed to get JVM reference"); - } - - auto callback_ref = env->NewGlobalRef(function_j); - if (!callback_ref) { - throw ScalarFunctionException("Could not create global reference for scalar function callback"); - } - - try { - auto apply_method = get_scalar_callback_method( - env, function_j, "(Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V", "execute", - "Could not find execute(ByteBuffer, ByteBuffer, ByteBuffer) on scalar function callback"); - auto state = new JavaScalarFunctionState(vm, callback_ref, apply_method); - duckdb_scalar_function_set_extra_info(scalar_function, state, destroy_java_scalar_function_state); - duckdb_scalar_function_set_function(scalar_function, execute_java_scalar_function_capi); - duckdb_scalar_function_set_init(scalar_function, init_java_scalar_function_capi); - } catch (...) { - env->DeleteGlobalRef(callback_ref); - throw; - } -} - -static void destroy_java_scalar_function_state(void *extra_info) { - if (!extra_info) { - return; - } - delete reinterpret_cast(extra_info); -} - -static void destroy_java_scalar_function_local_state(void *state_ptr) { - if (!state_ptr) { - return; - } - - auto state = reinterpret_cast(state_ptr); - if (state->detach_when_done && state->vm) { - state->vm->DetachCurrentThread(); - } - delete state; -} - -static void init_java_scalar_function_capi(duckdb_init_info info) { - JavaScalarFunctionLocalState *local_state = nullptr; - try { - auto state = reinterpret_cast(duckdb_scalar_function_init_get_extra_info(info)); - if (!state) { - duckdb_scalar_function_init_set_error(info, "Invalid Java scalar function callback state"); - return; - } - - local_state = new JavaScalarFunctionLocalState(); - local_state->vm = state->vm; - local_state->env = nullptr; - local_state->detach_when_done = false; - get_or_attach_jni_env(local_state->vm, local_state->env, local_state->detach_when_done); - duckdb_scalar_function_init_set_state(info, local_state, destroy_java_scalar_function_local_state); - local_state = nullptr; - } catch (const std::exception &e) { - if (local_state) { - destroy_java_scalar_function_local_state(local_state); - } - duckdb_scalar_function_init_set_error(info, e.what()); - } -} - -static void execute_java_scalar_function_capi(duckdb_function_info info, duckdb_data_chunk input, - duckdb_vector output) { - auto state = reinterpret_cast(duckdb_scalar_function_get_extra_info(info)); - auto local_state = reinterpret_cast(duckdb_scalar_function_get_state(info)); - if (!state || !local_state || !local_state->env || !input || !output) { - duckdb_scalar_function_set_error(info, "Invalid Java scalar function callback state"); - return; - } - - try { - execute_java_scalar_function(local_state->env, *state, info, input, output); - } catch (const std::exception &e) { - duckdb_scalar_function_set_error(info, e.what()); - } -} diff --git a/src/jni/scalar_functions.hpp b/src/jni/scalar_functions.hpp deleted file mode 100644 index 966213bb0..000000000 --- a/src/jni/scalar_functions.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "bindings.hpp" - -void scalar_function_set_function(JNIEnv *env, jobject scalar_function_buf, jobject function_j); diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index eb535881b..3517fc13f 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -37,7 +37,9 @@ public class DuckDBBindings { static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalarFunction); - static native void duckdb_scalar_function_set_function(ByteBuffer scalarFunction, Object function); + static native void duckdb_scalar_function_set_extra_info(ByteBuffer scalarFunction, Object callback); + + static native void duckdb_scalar_function_set_function(ByteBuffer scalarFunction); static native void duckdb_scalar_function_set_error(ByteBuffer functionInfo, byte[] error); diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java index cf3d16c29..f4648d4fc 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java @@ -434,7 +434,9 @@ private DuckDBScalarFunctionBuilder setCallback(DuckDBScalarFunction function, b throws SQLException { this.callback = function; this.propagateNullsFlag = requiresNullPropagation; - duckdb_scalar_function_set_function(scalarFunctionRef, new DuckDBScalarFunctionWrapper(function)); + DuckDBScalarFunctionWrapper wrapper = new DuckDBScalarFunctionWrapper(function); + duckdb_scalar_function_set_extra_info(scalarFunctionRef, wrapper); + duckdb_scalar_function_set_function(scalarFunctionRef); return this; } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java index 9390bd135..7ca90b7ff 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -4,7 +4,7 @@ import java.nio.ByteBuffer; -final class DuckDBScalarFunctionWrapper { +class DuckDBScalarFunctionWrapper { private final DuckDBScalarFunction function; DuckDBScalarFunctionWrapper(DuckDBScalarFunction function) { @@ -21,6 +21,7 @@ public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, ByteBuffer o } } + // todo: stacktrace private static void reportError(ByteBuffer functionInfo, Throwable throwable) { String message = throwable.getMessage(); String className = throwable.getClass().getName();