diff --git a/UDF.MD b/UDF.MD index f8407c0e2..8e4b83dd3 100644 --- a/UDF.MD +++ b/UDF.MD @@ -1,11 +1,6 @@ # Java Scalar Functions (UDF) Use `DuckDBFunctions.scalarFunction()` to build and register scalar functions. -`register(java.sql.Connection)` returns a `DuckDBRegisteredFunction` with metadata about the registered scalar function. - -Registered functions are also tracked in a Java-side registry exposed by `DuckDBDriver.registeredFunctions()`. -This registry is bookkeeping for functions registered through the JDBC API, not an authoritative view of the DuckDB catalog. -`DuckDBDriver.clearFunctionsRegistry()` clears only the Java-side registry and does not de-register functions from DuckDB. ## Recommended API (Functional Interfaces) @@ -22,12 +17,12 @@ Use these overloads for simple functions: ```java try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_one") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withIntFunction(x -> x + 1) - .register(conn); + DuckDBFunctions.scalarFunction() + .withName("java_add_one") + .withParameter(int.class) + .withReturnType(int.class) + .withIntFunction(x -> x + 1) + .register(conn); } ``` @@ -41,9 +36,8 @@ SELECT java_add_one(41); try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { DuckDBFunctions.scalarFunction() .withName("java_weighted_sum") - .withParameter(Double.class) - .withParameter(Double.class) - .withReturnType(Double.class) + .withParameters(double.class, double.class) + .withReturnType(double.class) .withDoubleFunction((x, w) -> x * w + 10.0) .register(conn); } @@ -53,32 +47,38 @@ try (Connection conn = DriverManager.getConnection("jdbc:duckdb:")) { SELECT java_weighted_sum(2.5, 4.0); ``` -Behavior: +NULL handling behavior: -- `Function` and `BiFunction` callbacks receive `null` for NULL inputs; implement null handling in Java callback logic. -- `withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)` run with null propagation enabled. -- For `withVectorizedFunction(...)`, use `ctx.propagateNulls(true)` when you want stream-level null skipping and automatic NULL output. -- For `Supplier`, returning `null` writes NULL output. -- `Function` and `BiFunction` are fixed arity only (no varargs). +- Java functions that take object arguments (`Function` and `BiFunction` callbacks) are registered with + `duckdb_scalar_function_set_special_handling()` C API option enabled - `null` Java arguments are passed for `NULL` input values +- when `withNullInNullOut()` option is set (skips the `duckdb_scalar_function_set_special_handling()` call on function registration), + then DuckDB engine may skip the function call completely (and set `NULL` to the call result automatically), + but this does not happen in all cases (for example, when the function is applied to a relation where some rows are `NULL`), + so the function still can be passed `null` arguments and need to check them and return `null` accordingly +- Java functions that take primitive arguments (`withIntFunction(...)`, `withLongFunction(...)`, and `withDoubleFunction(...)`) + are never passed `null` arguments as an input - even when the DuckDB engine calls such function with `NULL` value - + the call will be skipped on Java side and `null` value will be returned (returning `null` writes NULL output) Runtime error model: -- Callback-time reader/writer/context type and value failures throw `DuckDBFunctionException`. +- Callback-time reader/writer/context type and value failures throw `DuckDBFunctions.CallException extends RuntimeException`. - Invalid row/column indexes throw `IndexOutOfBoundsException`. -- `SQLException` remains for registration-time API usage and type declaration/validation. +- `SQLException` is used for registration-time API usage and type declaration/validation. ## Type declaration and mapping `withParameter(...)` and `withReturnType(...)` accept: -- `Class` +- `Class` (Object or primitive class, like `int.class` or `Integer.TYPE`) - `DuckDBColumnType` - `DuckDBLogicalType` Common class mappings include: -- `Integer` -> `INTEGER` -- `Long` -> `BIGINT` +- `int` -> `INTEGER` +- `long` -> `BIGINT` +- `float` -> `FLOAT` +- `double` -> `DOUBLE` - `String` -> `VARCHAR` - `BigDecimal` -> `DECIMAL` - `BigInteger` -> `HUGEINT` @@ -140,12 +140,16 @@ Notes: - `withVarArgsFunction(Function)` - `withVectorizedFunction(DuckDBScalarFunction)` - `withVolatile()` -- `withSpecialHandling()` +- `withNullInNullOut()` - `register(java.sql.Connection)` ## Registered Function Metadata And Registry -`DuckDBRegisteredFunction` exposes immutable metadata about the successful registration result: +Registered functions, returned from `.register()` call, are additionally tracked in a Java-side registry exposed by `DuckDBDriver.registeredFunctions()`. +This registry provides bookkeeping for functions registered through the JDBC API, not an authoritative view of the DuckDB catalog. +`DuckDBDriver.clearFunctionsRegistry()` clears only the Java-side registry and does not de-register functions from DuckDB. + +`DuckDBFunctions.RegisteredFunction` exposes immutable metadata about the successful registration result: - `name()` - `functionKind()` @@ -156,47 +160,50 @@ Notes: To inspect Java-side registrations: ```java -List functions = DuckDBDriver.registeredFunctions(); +List functions = DuckDBDriver.registeredFunctions(); ``` The returned list is read-only. Duplicate function names may appear in the registry. ## Advanced API (`DuckDBScalarFunction`) -Use `withVectorizedFunction(...)` for full context control through `DuckDBScalarContext`. +Use `withVectorizedFunction(...)` to access multiple input rows (when the scalar function is applied +to a row set from a relation) in a single call. DuckDB engine splits the input row set into "data chunks" +(represented in Java as `DuckDBDataChunkReader`) each chunk containing up to 2048 rows. + +`DuckDBScalarFunction` callback receives the `DuckDBDataChunkReader` as an `input` argument. +Data chunk contains a number of "data vectors" (`DuckDBReadableVector`) - single vector for each input column. +Vectors can be accessed using `input.vector(columnIndex)`. It writes the results into the `output` `DuckDBWritableVector`. +Results in the `output` vector must be set **on the same `row` indices** that are used to read `input`. + +`input.stream()` call returns a `LongStream` of `row` indices that can be used with Java Streams API. + +- `DuckDBDataChunkReader`, `DuckDBReadableVector`, and `DuckDBWritableVector` are valid only during callback execution. Example with multiple input types (`TIMESTAMP`, `VARCHAR`, `DOUBLE`) and `VARCHAR` output: -```java -try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); +``` +try (Connection conn = DriverManager.getConnection("jdbc:duckdb:"); Statement stmt = conn.createStatement(); DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { DuckDBFunctions.scalarFunction() .withName("java_event_label") - .withParameter(tsType) - .withParameter(strType) - .withParameter(dblType) + .withParameters(tsType, strType, dblType) .withReturnType(strType) - .withVectorizedFunction(ctx -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> { - String value = row.getLocalDateTime(0) + " | " - + row.getString(1).trim().toUpperCase() - + " | " + row.getDouble(2); - row.setString(value); + .withVectorizedFunction((input, output) -> { + input.stream().forEach(row -> { + String value = input.vector(0).getLocalDateTime(row) + " | " + + String.valueOf(input.vector(1).getString(row)).trim().toUpperCase() + " | " + + input.vector(2).getDouble(row, 0.0d); + output.setString(row, value); }); }) .register(conn); + ... } ``` ```sql SELECT java_event_label(TIMESTAMP '2026-04-04 12:00:00', 'launch', 4.5); ``` - -Lifecycle rules: - -- `DuckDBScalarContext`, `DuckDBScalarRow`, `DuckDBReadableVector`, and `DuckDBWritableVector` are valid only during callback execution. -- `DuckDBReadableVector` and `DuckDBWritableVector` are abstract callback runtime types (not interfaces). -- Write exactly one output value per input row for each callback invocation. -- With `propagateNulls(true)`, `DuckDBScalarContext.stream()` skips rows that contain NULL in any input column and writes NULL to the output for those rows. diff --git a/src/main/java/org/duckdb/DuckDBDataChunkReader.java b/src/main/java/org/duckdb/DuckDBDataChunkReader.java index 23fc16190..d91b78b64 100644 --- a/src/main/java/org/duckdb/DuckDBDataChunkReader.java +++ b/src/main/java/org/duckdb/DuckDBDataChunkReader.java @@ -3,6 +3,8 @@ import static org.duckdb.DuckDBBindings.*; import java.nio.ByteBuffer; +import java.util.stream.LongStream; +import org.duckdb.DuckDBFunctions.FunctionException; /** * Reader over callback input data chunks. @@ -17,12 +19,18 @@ public final class DuckDBDataChunkReader { DuckDBDataChunkReader(ByteBuffer chunkRef) { if (chunkRef == null) { - throw new DuckDBFunctionException("Invalid data chunk reference"); + throw new FunctionException("Invalid data chunk reference"); } this.chunkRef = chunkRef; this.rowCount = duckdb_data_chunk_get_size(chunkRef); this.columnCount = duckdb_data_chunk_get_column_count(chunkRef); this.vectors = new DuckDBReadableVector[Math.toIntExact(columnCount)]; + + for (long columnIndex = 0; columnIndex < columnCount; columnIndex++) { + ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); + int arrayIndex = Math.toIntExact(columnIndex); + vectors[arrayIndex] = new DuckDBReadableVector(vectorRef, rowCount); + } } public long rowCount() { @@ -33,17 +41,15 @@ public long columnCount() { return columnCount; } + public LongStream stream() { + return LongStream.range(0, rowCount); + } + public DuckDBReadableVector vector(long columnIndex) { if (columnIndex < 0 || columnIndex >= columnCount) { throw new IndexOutOfBoundsException("Column index out of bounds: " + columnIndex); } int arrayIndex = Math.toIntExact(columnIndex); - DuckDBReadableVector vector = vectors[arrayIndex]; - if (vector == null) { - ByteBuffer vectorRef = duckdb_data_chunk_get_vector(chunkRef, columnIndex); - vector = new DuckDBReadableVectorImpl(vectorRef, rowCount); - vectors[arrayIndex] = vector; - } - return vector; + return vectors[arrayIndex]; } } diff --git a/src/main/java/org/duckdb/DuckDBDriver.java b/src/main/java/org/duckdb/DuckDBDriver.java index 9e9c4cffa..3797e9ced 100644 --- a/src/main/java/org/duckdb/DuckDBDriver.java +++ b/src/main/java/org/duckdb/DuckDBDriver.java @@ -16,6 +16,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; +import org.duckdb.DuckDBFunctions.RegisteredFunction; import org.duckdb.io.LimitedInputStream; public class DuckDBDriver implements java.sql.Driver { @@ -42,7 +43,7 @@ public class DuckDBDriver implements java.sql.Driver { private static boolean pinnedDbRefsShutdownHookRegistered = false; private static boolean pinnedDbRefsShutdownHookRun = false; - private static final ArrayList functionsRegistry = new ArrayList<>(); + private static final ArrayList functionsRegistry = new ArrayList<>(); private static final ReentrantLock functionsRegistryLock = new ReentrantLock(); private static final Set supportedOptions = new LinkedHashSet<>(); @@ -266,7 +267,7 @@ public static boolean shutdownQueryCancelScheduler() { return true; } - public static List registeredFunctions() { + public static List registeredFunctions() { functionsRegistryLock.lock(); try { return Collections.unmodifiableList(new ArrayList<>(functionsRegistry)); @@ -284,7 +285,7 @@ public static void clearFunctionsRegistry() { } } - static void registerFunction(DuckDBRegisteredFunction function) { + static void registerFunction(RegisteredFunction function) { functionsRegistryLock.lock(); try { functionsRegistry.add(function); diff --git a/src/main/java/org/duckdb/DuckDBFunctionException.java b/src/main/java/org/duckdb/DuckDBFunctionException.java deleted file mode 100644 index 2010fe433..000000000 --- a/src/main/java/org/duckdb/DuckDBFunctionException.java +++ /dev/null @@ -1,13 +0,0 @@ -package org.duckdb; - -public final class DuckDBFunctionException extends RuntimeException { - private static final long serialVersionUID = 1L; - - public DuckDBFunctionException(String message) { - super(message); - } - - public DuckDBFunctionException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/src/main/java/org/duckdb/DuckDBFunctions.java b/src/main/java/org/duckdb/DuckDBFunctions.java index 0a9337301..201a1ffa4 100644 --- a/src/main/java/org/duckdb/DuckDBFunctions.java +++ b/src/main/java/org/duckdb/DuckDBFunctions.java @@ -1,9 +1,12 @@ package org.duckdb; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; public final class DuckDBFunctions { - public enum DuckDBFunctionKind { SCALAR } + public enum Kind { SCALAR } private DuckDBFunctions() { } @@ -11,4 +14,108 @@ private DuckDBFunctions() { public static DuckDBScalarFunctionBuilder scalarFunction() throws SQLException { return new DuckDBScalarFunctionBuilder(); } + + static RegisteredFunction createRegisteredFunction(String name, List parameterTypes, + List parameterColumnTypes, + DuckDBLogicalType returnType, DuckDBColumnType returnColumnType, + DuckDBScalarFunction function, DuckDBLogicalType varArgType, + boolean volatileFlag, boolean specialHandlingFlag, + boolean propagateNullsFlag) { + return new RegisteredFunction(name, Kind.SCALAR, Collections.unmodifiableList(new ArrayList<>(parameterTypes)), + Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), returnType, + returnColumnType, function, varArgType, volatileFlag, specialHandlingFlag, + propagateNullsFlag); + } + + public static class FunctionException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public FunctionException(String message) { + super(message); + } + + public FunctionException(String message, Throwable cause) { + super(message, cause); + } + } + + public static final class RegisteredFunction { + private final String name; + private final Kind functionKind; + private final List parameterTypes; + private final List parameterColumnTypes; + private final DuckDBLogicalType returnType; + private final DuckDBColumnType returnColumnType; + private final DuckDBScalarFunction function; + private final DuckDBLogicalType varArgType; + private final boolean volatileFlag; + private final boolean nullInNullOutFlag; + private final boolean propagateNullsFlag; + + private RegisteredFunction(String name, Kind functionKind, List parameterTypes, + List parameterColumnTypes, DuckDBLogicalType returnType, + DuckDBColumnType returnColumnType, DuckDBScalarFunction function, + DuckDBLogicalType varArgType, boolean volatileFlag, boolean nullInNullOutFlag, + boolean propagateNullsFlag) { + this.name = name; + this.functionKind = functionKind; + this.parameterTypes = parameterTypes; + this.parameterColumnTypes = parameterColumnTypes; + this.returnType = returnType; + this.returnColumnType = returnColumnType; + this.function = function; + this.varArgType = varArgType; + this.volatileFlag = volatileFlag; + this.nullInNullOutFlag = nullInNullOutFlag; + this.propagateNullsFlag = propagateNullsFlag; + } + + public String name() { + return name; + } + + public Kind functionKind() { + return functionKind; + } + + public List parameterTypes() { + return parameterTypes; + } + + public List parameterColumnTypes() { + return parameterColumnTypes; + } + + public DuckDBLogicalType returnType() { + return returnType; + } + + public DuckDBColumnType returnColumnType() { + return returnColumnType; + } + + public DuckDBScalarFunction function() { + return function; + } + + public DuckDBLogicalType varArgType() { + return varArgType; + } + + public boolean isVolatile() { + return volatileFlag; + } + + public boolean isNullInNullOut() { + return nullInNullOutFlag; + } + + public boolean propagateNulls() { + return propagateNullsFlag; + } + + public boolean isScalar() { + return functionKind == Kind.SCALAR; + } + } } diff --git a/src/main/java/org/duckdb/DuckDBReadableVector.java b/src/main/java/org/duckdb/DuckDBReadableVector.java index 52324ad70..56fef2187 100644 --- a/src/main/java/org/duckdb/DuckDBReadableVector.java +++ b/src/main/java/org/duckdb/DuckDBReadableVector.java @@ -1,86 +1,395 @@ package org.duckdb; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.sql.Date; import java.sql.Timestamp; +import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; import java.util.stream.LongStream; - -/** - * Read-only scalar callback view over a DuckDB vector. - * - *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value errors. - * Invalid row indexes throw {@link IndexOutOfBoundsException}. - */ -public abstract class DuckDBReadableVector { - public abstract DuckDBColumnType getType(); - - public abstract long rowCount(); - - public abstract LongStream rowIndexStream(); - - public abstract boolean isNull(long row); - - public abstract boolean getBoolean(long row); - - public abstract boolean getBoolean(long row, boolean defaultVal); - - public abstract byte getByte(long row); - - public abstract byte getByte(long row, byte defaultVal); - - public abstract short getShort(long row); - - public abstract short getShort(long row, short defaultVal); - - public abstract short getUint8(long row); - - public abstract short getUint8(long row, short defaultVal); - - public abstract int getUint16(long row); - - public abstract int getUint16(long row, int defaultVal); - - public abstract int getInt(long row); - - public abstract int getInt(long row, int defaultVal); - - public abstract long getUint32(long row); - - public abstract long getUint32(long row, long defaultVal); - - public abstract long getLong(long row); - - public abstract long getLong(long row, long defaultVal); - - public abstract BigInteger getHugeInt(long row); - - public abstract BigInteger getUHugeInt(long row); - - public abstract BigInteger getUint64(long row); - - public abstract float getFloat(long row); - - public abstract float getFloat(long row, float defaultVal); - - public abstract double getDouble(long row); - - public abstract double getDouble(long row, double defaultVal); - - public abstract LocalDate getLocalDate(long row); - - public abstract Date getDate(long row); - - public abstract LocalDateTime getLocalDateTime(long row); - - public abstract Timestamp getTimestamp(long row); - - public abstract OffsetDateTime getOffsetDateTime(long row); - - public abstract BigDecimal getBigDecimal(long row); - - public abstract String getString(long row); +import org.duckdb.DuckDBFunctions.FunctionException; + +public final class DuckDBReadableVector { + private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private final ByteBuffer validity; + + DuckDBReadableVector(ByteBuffer vectorRef, long rowCount) { + if (vectorRef == null) { + throw new FunctionException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new FunctionException("Failed to resolve vector type info", exception); + } + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + if (this.validity != null) { + this.validity.order(NATIVE_ORDER); + } + } + + public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + public long rowCount() { + return rowCount; + } + + public LongStream stream() { + return LongStream.range(0, rowCount); + } + + public boolean isNull(long row) { + checkRowIndex(row); + if (validity == null) { + return false; + } + int entryPos = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); + long mask = validity.getLong(entryPos); + return (mask & (1L << (row % Long.SIZE))) == 0; + } + + public boolean getBoolean(long row) { + requireType(DuckDBColumnType.BOOLEAN); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BOOLEAN, row); + } + return data.get(checkedRowIndex(row)) != 0; + } + + public boolean getBoolean(long row, boolean defaultValue) { + requireType(DuckDBColumnType.BOOLEAN); + return isNull(row) ? defaultValue : data.get(checkedRowIndex(row)) != 0; + } + + public byte getByte(long row) { + requireType(DuckDBColumnType.TINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.TINYINT, row); + } + return data.get(checkedRowIndex(row)); + } + + public byte getByte(long row, byte defaultValue) { + requireType(DuckDBColumnType.TINYINT); + return isNull(row) ? defaultValue : data.get(checkedRowIndex(row)); + } + + public short getShort(long row) { + requireType(DuckDBColumnType.SMALLINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.SMALLINT, row); + } + return data.getShort(checkedByteOffset(row, Short.BYTES)); + } + + public short getShort(long row, short defaultValue) { + requireType(DuckDBColumnType.SMALLINT); + return isNull(row) ? defaultValue : data.getShort(checkedByteOffset(row, Short.BYTES)); + } + + public short getUint8(long row) { + requireType(DuckDBColumnType.UTINYINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UTINYINT, row); + } + return (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + public short getUint8(long row, short defaultValue) { + requireType(DuckDBColumnType.UTINYINT); + return isNull(row) ? defaultValue : (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); + } + + public int getUint16(long row) { + requireType(DuckDBColumnType.USMALLINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.USMALLINT, row); + } + return Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); + } + + public int getUint16(long row, int defaultValue) { + requireType(DuckDBColumnType.USMALLINT); + return isNull(row) ? defaultValue : Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); + } + + public int getInt(long row) { + requireType(DuckDBColumnType.INTEGER); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.INTEGER, row); + } + return data.getInt(checkedByteOffset(row, Integer.BYTES)); + } + + public int getInt(long row, int defaultValue) { + requireType(DuckDBColumnType.INTEGER); + return isNull(row) ? defaultValue : data.getInt(checkedByteOffset(row, Integer.BYTES)); + } + + public long getUint32(long row) { + requireType(DuckDBColumnType.UINTEGER); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.UINTEGER, row); + } + return Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + public long getUint32(long row, long defaultValue) { + requireType(DuckDBColumnType.UINTEGER); + return isNull(row) ? defaultValue : Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + public long getLong(long row) { + requireType(DuckDBColumnType.BIGINT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.BIGINT, row); + } + return data.getLong(checkedByteOffset(row, Long.BYTES)); + } + + public long getLong(long row, long defaultValue) { + requireType(DuckDBColumnType.BIGINT); + return isNull(row) ? defaultValue : data.getLong(checkedByteOffset(row, Long.BYTES)); + } + + public BigInteger getHugeInt(long row) { + requireType(DuckDBColumnType.HUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toBigInteger(lower, upper); + } + + public BigInteger getUHugeInt(long row) { + requireType(DuckDBColumnType.UHUGEINT); + if (isNull(row)) { + return null; + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return DuckDBHugeInt.toUnsignedBigInteger(lower, upper); + } + + public BigInteger getUint64(long row) { + requireType(DuckDBColumnType.UBIGINT); + if (isNull(row)) { + return null; + } + long value = data.getLong(checkedByteOffset(row, Long.BYTES)); + return unsignedLongToBigInteger(value); + } + + public float getFloat(long row) { + requireType(DuckDBColumnType.FLOAT); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.FLOAT, row); + } + return data.getFloat(checkedByteOffset(row, Float.BYTES)); + } + + public float getFloat(long row, float defaultValue) { + requireType(DuckDBColumnType.FLOAT); + return isNull(row) ? defaultValue : data.getFloat(checkedByteOffset(row, Float.BYTES)); + } + + public double getDouble(long row) { + requireType(DuckDBColumnType.DOUBLE); + if (isNull(row)) { + throw primitiveNullValue(DuckDBColumnType.DOUBLE, row); + } + return data.getDouble(checkedByteOffset(row, Double.BYTES)); + } + + public double getDouble(long row, double defaultValue) { + requireType(DuckDBColumnType.DOUBLE); + return isNull(row) ? defaultValue : data.getDouble(checkedByteOffset(row, Double.BYTES)); + } + + public LocalDate getLocalDate(long row) { + requireType(DuckDBColumnType.DATE); + if (isNull(row)) { + return null; + } + return LocalDate.ofEpochDay(data.getInt(checkedByteOffset(row, Integer.BYTES))); + } + + public Date getDate(long row) { + LocalDate value = getLocalDate(row); + return value == null ? null : Date.valueOf(value); + } + + public LocalDateTime getLocalDateTime(long row) { + requireTimestampType(); + if (isNull(row)) { + return null; + } + long epochValue = data.getLong(checkedByteOffset(row, Long.BYTES)); + try { + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); + case DUCKDB_TYPE_TIMESTAMP_MS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); + case DUCKDB_TYPE_TIMESTAMP: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); + case DUCKDB_TYPE_TIMESTAMP_NS: + return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); + case DUCKDB_TYPE_TIMESTAMP_TZ: + return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); + default: + throw new FunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } catch (java.sql.SQLException exception) { + throw new FunctionException("Failed to decode timestamp at row " + row, exception); + } + } + + public Timestamp getTimestamp(long row) { + LocalDateTime value = getLocalDateTime(row); + return value == null ? null : Timestamp.valueOf(value); + } + + public OffsetDateTime getOffsetDateTime(long row) { + requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); + if (isNull(row)) { + return null; + } + long micros = data.getLong(checkedByteOffset(row, Long.BYTES)); + Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); + return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); + } + + public BigDecimal getBigDecimal(long row) { + requireType(DuckDBColumnType.DECIMAL); + if (isNull(row)) { + return null; + } + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_INTEGER: + return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_BIGINT: + return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), typeInfo.decimalMeta.scale); + case DUCKDB_TYPE_HUGEINT: { + int offset = checkedByteOffset(row, typeInfo.widthBytes); + long lower = data.getLong(offset); + long upper = data.getLong(offset + Long.BYTES); + return new BigDecimal(upper) + .multiply(ULONG_MULTIPLIER) + .add(new BigDecimal(Long.toUnsignedString(lower))) + .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); + } + default: + throw new FunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + } + + public String getString(long row) { + requireType(DuckDBColumnType.VARCHAR); + if (isNull(row)) { + return null; + } + byte[] bytes = duckdb_vector_get_string(data, row); + if (bytes == null) { + return null; + } + return new String(bytes, UTF_8); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void requireType(DuckDBColumnType expected) { + if (typeInfo.columnType != expected) { + throw new FunctionException("Expected vector type " + expected + ", found " + typeInfo.columnType); + } + } + + private void requireTimestampType() { + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return; + default: + throw new FunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private void checkRowIndex(long row) { + if (row < 0 || row >= rowCount) { + throw new IndexOutOfBoundsException("Row index out of bounds: " + row); + } + } + + private int checkedRowIndex(long row) { + checkRowIndex(row); + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + checkRowIndex(row); + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } + + private static Instant instantFromEpoch(long value, ChronoUnit unit) { + switch (unit) { + case SECONDS: + return Instant.ofEpochSecond(value); + case MILLIS: + return Instant.ofEpochMilli(value); + case MICROS: { + long epochSecond = Math.floorDiv(value, 1_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000L) * 1000L; + return Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + case NANOS: { + long epochSecond = Math.floorDiv(value, 1_000_000_000L); + long nanoAdjustment = Math.floorMod(value, 1_000_000_000L); + return Instant.ofEpochSecond(epochSecond, nanoAdjustment); + } + default: + throw new FunctionException("Unsupported unit type: " + unit); + } + } + + private static BigInteger unsignedLongToBigInteger(long value) { + if (value >= 0) { + return BigInteger.valueOf(value); + } + return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); + } + + private static FunctionException primitiveNullValue(DuckDBColumnType type, long row) { + return new FunctionException("Primitive value for " + type + " at row " + row + " is NULL"); + } } diff --git a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java b/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java deleted file mode 100644 index 58b0a3f03..000000000 --- a/src/main/java/org/duckdb/DuckDBReadableVectorImpl.java +++ /dev/null @@ -1,426 +0,0 @@ -package org.duckdb; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.duckdb.DuckDBBindings.*; - -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.sql.Date; -import java.sql.Timestamp; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.OffsetDateTime; -import java.time.ZoneId; -import java.time.temporal.ChronoUnit; -import java.util.stream.LongStream; - -final class DuckDBReadableVectorImpl extends DuckDBReadableVector { - private static final BigDecimal ULONG_MULTIPLIER = new BigDecimal("18446744073709551616"); - private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); - - private final ByteBuffer vectorRef; - private final long rowCount; - private final DuckDBVectorTypeInfo typeInfo; - private final ByteBuffer data; - private final ByteBuffer validity; - - DuckDBReadableVectorImpl(ByteBuffer vectorRef, long rowCount) { - if (vectorRef == null) { - throw new DuckDBFunctionException("Invalid vector reference"); - } - this.vectorRef = vectorRef; - this.rowCount = rowCount; - try { - this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); - } catch (java.sql.SQLException exception) { - throw new DuckDBFunctionException("Failed to resolve vector type info", exception); - } - this.data = - duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); - ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); - this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); - } - - @Override - public DuckDBColumnType getType() { - return typeInfo.columnType; - } - - @Override - public long rowCount() { - return rowCount; - } - - @Override - public LongStream rowIndexStream() { - return LongStream.range(0, rowCount); - } - - @Override - public boolean isNull(long row) { - checkRowIndex(row); - if (validity == null) { - return false; - } - int entryPos = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); - long mask = validity.getLong(entryPos); - return (mask & (1L << (row % Long.SIZE))) == 0; - } - - @Override - public boolean getBoolean(long row) { - requireType(DuckDBColumnType.BOOLEAN); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.BOOLEAN, row); - } - return data.get(checkedRowIndex(row)) != 0; - } - - @Override - public boolean getBoolean(long row, boolean defaultVal) { - requireType(DuckDBColumnType.BOOLEAN); - return isNull(row) ? defaultVal : data.get(checkedRowIndex(row)) != 0; - } - - @Override - public byte getByte(long row) { - requireType(DuckDBColumnType.TINYINT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.TINYINT, row); - } - return data.get(checkedRowIndex(row)); - } - - @Override - public byte getByte(long row, byte defaultVal) { - requireType(DuckDBColumnType.TINYINT); - return isNull(row) ? defaultVal : data.get(checkedRowIndex(row)); - } - - @Override - public short getShort(long row) { - requireType(DuckDBColumnType.SMALLINT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.SMALLINT, row); - } - return data.getShort(checkedByteOffset(row, Short.BYTES)); - } - - @Override - public short getShort(long row, short defaultVal) { - requireType(DuckDBColumnType.SMALLINT); - return isNull(row) ? defaultVal : data.getShort(checkedByteOffset(row, Short.BYTES)); - } - - @Override - public short getUint8(long row) { - requireType(DuckDBColumnType.UTINYINT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.UTINYINT, row); - } - return (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); - } - - @Override - public short getUint8(long row, short defaultVal) { - requireType(DuckDBColumnType.UTINYINT); - return isNull(row) ? defaultVal : (short) Byte.toUnsignedInt(data.get(checkedRowIndex(row))); - } - - @Override - public int getUint16(long row) { - requireType(DuckDBColumnType.USMALLINT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.USMALLINT, row); - } - return Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); - } - - @Override - public int getUint16(long row, int defaultVal) { - requireType(DuckDBColumnType.USMALLINT); - return isNull(row) ? defaultVal : Short.toUnsignedInt(data.getShort(checkedByteOffset(row, Short.BYTES))); - } - - @Override - public int getInt(long row) { - requireType(DuckDBColumnType.INTEGER); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.INTEGER, row); - } - return data.getInt(checkedByteOffset(row, Integer.BYTES)); - } - - @Override - public int getInt(long row, int defaultVal) { - requireType(DuckDBColumnType.INTEGER); - return isNull(row) ? defaultVal : data.getInt(checkedByteOffset(row, Integer.BYTES)); - } - - @Override - public long getUint32(long row) { - requireType(DuckDBColumnType.UINTEGER); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.UINTEGER, row); - } - return Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); - } - - @Override - public long getUint32(long row, long defaultVal) { - requireType(DuckDBColumnType.UINTEGER); - return isNull(row) ? defaultVal : Integer.toUnsignedLong(data.getInt(checkedByteOffset(row, Integer.BYTES))); - } - - @Override - public long getLong(long row) { - requireType(DuckDBColumnType.BIGINT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.BIGINT, row); - } - return data.getLong(checkedByteOffset(row, Long.BYTES)); - } - - @Override - public long getLong(long row, long defaultVal) { - requireType(DuckDBColumnType.BIGINT); - return isNull(row) ? defaultVal : data.getLong(checkedByteOffset(row, Long.BYTES)); - } - - @Override - public BigInteger getHugeInt(long row) { - requireType(DuckDBColumnType.HUGEINT); - if (isNull(row)) { - return null; - } - int offset = checkedByteOffset(row, typeInfo.widthBytes); - long lower = data.getLong(offset); - long upper = data.getLong(offset + Long.BYTES); - return DuckDBHugeInt.toBigInteger(lower, upper); - } - - @Override - public BigInteger getUHugeInt(long row) { - requireType(DuckDBColumnType.UHUGEINT); - if (isNull(row)) { - return null; - } - int offset = checkedByteOffset(row, typeInfo.widthBytes); - long lower = data.getLong(offset); - long upper = data.getLong(offset + Long.BYTES); - return DuckDBHugeInt.toUnsignedBigInteger(lower, upper); - } - - @Override - public BigInteger getUint64(long row) { - requireType(DuckDBColumnType.UBIGINT); - if (isNull(row)) { - return null; - } - long value = data.getLong(checkedByteOffset(row, Long.BYTES)); - return unsignedLongToBigInteger(value); - } - - @Override - public float getFloat(long row) { - requireType(DuckDBColumnType.FLOAT); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.FLOAT, row); - } - return data.getFloat(checkedByteOffset(row, Float.BYTES)); - } - - @Override - public float getFloat(long row, float defaultVal) { - requireType(DuckDBColumnType.FLOAT); - return isNull(row) ? defaultVal : data.getFloat(checkedByteOffset(row, Float.BYTES)); - } - - @Override - public double getDouble(long row) { - requireType(DuckDBColumnType.DOUBLE); - if (isNull(row)) { - throw primitiveNullValue(DuckDBColumnType.DOUBLE, row); - } - return data.getDouble(checkedByteOffset(row, Double.BYTES)); - } - - @Override - public double getDouble(long row, double defaultVal) { - requireType(DuckDBColumnType.DOUBLE); - return isNull(row) ? defaultVal : data.getDouble(checkedByteOffset(row, Double.BYTES)); - } - - @Override - public LocalDate getLocalDate(long row) { - requireType(DuckDBColumnType.DATE); - if (isNull(row)) { - return null; - } - return LocalDate.ofEpochDay(data.getInt(checkedByteOffset(row, Integer.BYTES))); - } - - @Override - public Date getDate(long row) { - LocalDate value = getLocalDate(row); - return value == null ? null : Date.valueOf(value); - } - - @Override - public LocalDateTime getLocalDateTime(long row) { - requireTimestampType(); - if (isNull(row)) { - return null; - } - long epochValue = data.getLong(checkedByteOffset(row, Long.BYTES)); - try { - switch (typeInfo.capiType) { - case DUCKDB_TYPE_TIMESTAMP_S: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.SECONDS, null); - case DUCKDB_TYPE_TIMESTAMP_MS: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MILLIS, null); - case DUCKDB_TYPE_TIMESTAMP: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.MICROS, null); - case DUCKDB_TYPE_TIMESTAMP_NS: - return DuckDBTimestamp.localDateTimeFromTimestamp(epochValue, ChronoUnit.NANOS, null); - case DUCKDB_TYPE_TIMESTAMP_TZ: - return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(epochValue, ChronoUnit.MICROS, null); - default: - throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } catch (java.sql.SQLException exception) { - throw new DuckDBFunctionException("Failed to decode timestamp at row " + row, exception); - } - } - - @Override - public Timestamp getTimestamp(long row) { - LocalDateTime value = getLocalDateTime(row); - return value == null ? null : Timestamp.valueOf(value); - } - - @Override - public OffsetDateTime getOffsetDateTime(long row) { - requireType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE); - if (isNull(row)) { - return null; - } - long micros = data.getLong(checkedByteOffset(row, Long.BYTES)); - Instant instant = instantFromEpoch(micros, ChronoUnit.MICROS); - return instant.atZone(ZoneId.systemDefault()).toOffsetDateTime(); - } - - @Override - public BigDecimal getBigDecimal(long row) { - requireType(DuckDBColumnType.DECIMAL); - if (isNull(row)) { - return null; - } - switch (typeInfo.storageType) { - case DUCKDB_TYPE_SMALLINT: - return BigDecimal.valueOf(data.getShort(checkedByteOffset(row, Short.BYTES)), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_INTEGER: - return BigDecimal.valueOf(data.getInt(checkedByteOffset(row, Integer.BYTES)), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_BIGINT: - return BigDecimal.valueOf(data.getLong(checkedByteOffset(row, Long.BYTES)), typeInfo.decimalMeta.scale); - case DUCKDB_TYPE_HUGEINT: { - int offset = checkedByteOffset(row, typeInfo.widthBytes); - long lower = data.getLong(offset); - long upper = data.getLong(offset + Long.BYTES); - return new BigDecimal(upper) - .multiply(ULONG_MULTIPLIER) - .add(new BigDecimal(Long.toUnsignedString(lower))) - .scaleByPowerOfTen(typeInfo.decimalMeta.scale * -1); - } - default: - throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); - } - } - - @Override - public String getString(long row) { - requireType(DuckDBColumnType.VARCHAR); - if (isNull(row)) { - return null; - } - byte[] bytes = duckdb_vector_get_string(data, row); - if (bytes == null) { - return null; - } - return new String(bytes, UTF_8); - } - - ByteBuffer vectorRef() { - return vectorRef; - } - - private void requireType(DuckDBColumnType expected) { - if (typeInfo.columnType != expected) { - throw new DuckDBFunctionException("Expected vector type " + expected + ", found " + typeInfo.columnType); - } - } - - private void requireTimestampType() { - switch (typeInfo.columnType) { - case TIMESTAMP: - case TIMESTAMP_S: - case TIMESTAMP_MS: - case TIMESTAMP_NS: - case TIMESTAMP_WITH_TIME_ZONE: - return; - default: - throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } - - private void checkRowIndex(long row) { - if (row < 0 || row >= rowCount) { - throw new IndexOutOfBoundsException("Row index out of bounds: " + row); - } - } - - private int checkedRowIndex(long row) { - checkRowIndex(row); - return Math.toIntExact(row); - } - - private int checkedByteOffset(long row, int elementWidth) { - checkRowIndex(row); - return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); - } - - private static Instant instantFromEpoch(long value, ChronoUnit unit) { - switch (unit) { - case SECONDS: - return Instant.ofEpochSecond(value); - case MILLIS: - return Instant.ofEpochMilli(value); - case MICROS: { - long epochSecond = Math.floorDiv(value, 1_000_000L); - long nanoAdjustment = Math.floorMod(value, 1_000_000L) * 1000L; - return Instant.ofEpochSecond(epochSecond, nanoAdjustment); - } - case NANOS: { - long epochSecond = Math.floorDiv(value, 1_000_000_000L); - long nanoAdjustment = Math.floorMod(value, 1_000_000_000L); - return Instant.ofEpochSecond(epochSecond, nanoAdjustment); - } - default: - throw new DuckDBFunctionException("Unsupported unit type: " + unit); - } - } - - private static BigInteger unsignedLongToBigInteger(long value) { - if (value >= 0) { - return BigInteger.valueOf(value); - } - return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(Long.SIZE - 1); - } - - private static DuckDBFunctionException primitiveNullValue(DuckDBColumnType type, long row) { - return new DuckDBFunctionException("Primitive value for " + type + " at row " + row + " is NULL"); - } -} diff --git a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java b/src/main/java/org/duckdb/DuckDBRegisteredFunction.java deleted file mode 100644 index f11f6c968..000000000 --- a/src/main/java/org/duckdb/DuckDBRegisteredFunction.java +++ /dev/null @@ -1,98 +0,0 @@ -package org.duckdb; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public final class DuckDBRegisteredFunction { - private final String name; - private final DuckDBFunctions.DuckDBFunctionKind functionKind; - private final List parameterTypes; - private final List parameterColumnTypes; - private final DuckDBLogicalType returnType; - private final DuckDBColumnType returnColumnType; - private final DuckDBScalarFunction function; - private final DuckDBLogicalType varArgType; - private final boolean volatileFlag; - private final boolean specialHandlingFlag; - private final boolean propagateNullsFlag; - - private DuckDBRegisteredFunction(String name, DuckDBFunctions.DuckDBFunctionKind functionKind, - List parameterTypes, - List parameterColumnTypes, DuckDBLogicalType returnType, - DuckDBColumnType returnColumnType, DuckDBScalarFunction function, - DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, - boolean propagateNullsFlag) { - this.name = name; - this.functionKind = functionKind; - this.parameterTypes = parameterTypes; - this.parameterColumnTypes = parameterColumnTypes; - this.returnType = returnType; - this.returnColumnType = returnColumnType; - this.function = function; - this.varArgType = varArgType; - this.volatileFlag = volatileFlag; - this.specialHandlingFlag = specialHandlingFlag; - this.propagateNullsFlag = propagateNullsFlag; - } - - public String name() { - return name; - } - - public DuckDBFunctions.DuckDBFunctionKind functionKind() { - return functionKind; - } - - public List parameterTypes() { - return parameterTypes; - } - - public List parameterColumnTypes() { - return parameterColumnTypes; - } - - public DuckDBLogicalType returnType() { - return returnType; - } - - public DuckDBColumnType returnColumnType() { - return returnColumnType; - } - - public DuckDBScalarFunction function() { - return function; - } - - public DuckDBLogicalType varArgType() { - return varArgType; - } - - public boolean isVolatile() { - return volatileFlag; - } - - public boolean hasSpecialHandling() { - return specialHandlingFlag; - } - - public boolean propagateNulls() { - return propagateNullsFlag; - } - - public boolean isScalar() { - return functionKind == DuckDBFunctions.DuckDBFunctionKind.SCALAR; - } - - static DuckDBRegisteredFunction of(String name, List parameterTypes, - List parameterColumnTypes, DuckDBLogicalType returnType, - DuckDBColumnType returnColumnType, DuckDBScalarFunction function, - DuckDBLogicalType varArgType, boolean volatileFlag, boolean specialHandlingFlag, - boolean propagateNullsFlag) { - return new DuckDBRegisteredFunction(name, DuckDBFunctions.DuckDBFunctionKind.SCALAR, - Collections.unmodifiableList(new ArrayList<>(parameterTypes)), - Collections.unmodifiableList(new ArrayList<>(parameterColumnTypes)), - returnType, returnColumnType, function, varArgType, volatileFlag, - specialHandlingFlag, propagateNullsFlag); - } -} diff --git a/src/main/java/org/duckdb/DuckDBScalarContext.java b/src/main/java/org/duckdb/DuckDBScalarContext.java deleted file mode 100644 index ad0c927ae..000000000 --- a/src/main/java/org/duckdb/DuckDBScalarContext.java +++ /dev/null @@ -1,90 +0,0 @@ -package org.duckdb; - -import java.util.stream.LongStream; -import java.util.stream.Stream; - -/** - * Per-invocation scalar callback context. - * - *

Runtime type/value failures are surfaced as {@link DuckDBFunctionException}. Invalid row or - * column indexes throw {@link IndexOutOfBoundsException}. - */ -public final class DuckDBScalarContext { - private final DuckDBDataChunkReader input; - private final DuckDBWritableVector output; - private boolean propagateNulls; - - DuckDBScalarContext(DuckDBDataChunkReader input, DuckDBWritableVector output, boolean propagateNulls) { - if (input == null) { - throw new IllegalArgumentException("Input chunk cannot be null"); - } - if (output == null) { - throw new IllegalArgumentException("Output vector cannot be null"); - } - this.input = input; - this.output = output; - this.propagateNulls = propagateNulls; - } - - public long rowCount() { - return input.rowCount(); - } - - public long columnCount() { - return input.columnCount(); - } - - public DuckDBReadableVector input(long columnIndex) { - return input.vector(columnIndex); - } - - public DuckDBWritableVector output() { - return output; - } - - public boolean nullsPropagated() { - return propagateNulls; - } - - public DuckDBScalarContext propagateNulls(boolean propagateNulls) { - this.propagateNulls = propagateNulls; - return this; - } - - DuckDBDataChunkReader inputChunk() { - return input; - } - - public Stream stream() { - LongStream rows = LongStream.range(0, rowCount()); - if (propagateNulls) { - rows = rows.filter(this::rowHasNoNullInputs); - } - return rows.mapToObj(this::row); - } - - public DuckDBScalarRow row(long rowIndex) { - checkRowIndex(rowIndex); - return new DuckDBScalarRow(this, rowIndex); - } - - DuckDBReadableVector inputUnchecked(long columnIndex) { - return input(columnIndex); - } - - private void checkRowIndex(long rowIndex) { - if (rowIndex < 0 || rowIndex >= rowCount()) { - throw new IndexOutOfBoundsException("Row index out of bounds: " + rowIndex); - } - } - - private boolean rowHasNoNullInputs(long rowIndex) { - for (long columnIndex = 0; columnIndex < columnCount(); columnIndex++) { - if (inputUnchecked(columnIndex).isNull(rowIndex)) { - output.setNull(rowIndex); - return false; - } - } - return true; - } -} diff --git a/src/main/java/org/duckdb/DuckDBScalarFunction.java b/src/main/java/org/duckdb/DuckDBScalarFunction.java index beef14162..ddaea82be 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunction.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunction.java @@ -8,8 +8,9 @@ public interface DuckDBScalarFunction { *

The context and all wrappers returned from it are valid only for the duration of the callback and must not * be retained. * - * @param ctx scalar function execution context for the current chunk + * @param input data chunk with function arguments + * @param output vector to write function results into * @throws Exception when function execution fails */ - void apply(DuckDBScalarContext ctx) throws Exception; + void apply(DuckDBDataChunkReader input, DuckDBWritableVector output) throws Exception; } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java index f67d20a72..ae40940da 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionAdapter.java @@ -21,6 +21,7 @@ import java.util.function.LongBinaryOperator; import java.util.function.LongUnaryOperator; import java.util.function.Supplier; +import org.duckdb.DuckDBFunctions.FunctionException; final class DuckDBScalarFunctionAdapter { private static final Map CODECS_BY_DUCKDB_TYPE = new LinkedHashMap<>(); @@ -118,23 +119,16 @@ static DuckDBScalarFunction unary(Function function, DuckDBColumnType para @SuppressWarnings("unchecked") Function typedFunction = (Function) function; TypeCodec inCodec = codecFor(parameterType, parameterJavaType); TypeCodec outCodec = codecFor(returnType, returnJavaType); - return ctx -> { - DuckDBReadableVector in = ctx.input(0); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); - boolean propagateNulls = ctx.nullsPropagated(); + return (input, output) -> { + DuckDBReadableVector in = input.vector(0); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { - if (propagateNulls && in.isNull(row)) { - out.setNull(row); - continue; - } Object argument = in.isNull(row) ? null : inCodec.read(in, row); Object result = typedFunction.apply(argument); - outCodec.write(out, row, result); - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute unary scalar function at row " + row, - exception); + outCodec.write(output, row, result); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute unary scalar function at row " + row, exception); } } }; @@ -148,162 +142,129 @@ static DuckDBScalarFunction binary(BiFunction function, DuckDBColumnTyp TypeCodec leftCodec = codecFor(leftType, leftJavaType); TypeCodec rightCodec = codecFor(rightType, rightJavaType); TypeCodec outCodec = codecFor(returnType, returnJavaType); - return ctx -> { - DuckDBReadableVector left = ctx.input(0); - DuckDBReadableVector right = ctx.input(1); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); - boolean propagateNulls = ctx.nullsPropagated(); + return (input, output) -> { + DuckDBReadableVector left = input.vector(0); + DuckDBReadableVector right = input.vector(1); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { - boolean leftIsNull = left.isNull(row); - boolean rightIsNull = right.isNull(row); - if (propagateNulls && (leftIsNull || rightIsNull)) { - out.setNull(row); - continue; - } - Object leftValue = leftIsNull ? null : leftCodec.read(left, row); - Object rightValue = rightIsNull ? null : rightCodec.read(right, row); + Object leftValue = left.isNull(row) ? null : leftCodec.read(left, row); + Object rightValue = right.isNull(row) ? null : rightCodec.read(right, row); Object result = typedFunction.apply(leftValue, rightValue); - outCodec.write(out, row, result); - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute binary scalar function at row " + row, - exception); + outCodec.write(output, row, result); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute binary scalar function at row " + row, exception); } } }; } static DuckDBScalarFunction intUnary(IntUnaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); - } - DuckDBReadableVector in = ctx.input(0); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector in = input.vector(0); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (in.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setInt(row, function.applyAsInt(in.getInt(row))); + output.setInt(row, function.applyAsInt(in.getInt(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withIntFunction at row " + row, exception); } } }; } static DuckDBScalarFunction intBinary(IntBinaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withIntFunction requires propagateNulls(true)"); - } - DuckDBReadableVector left = ctx.input(0); - DuckDBReadableVector right = ctx.input(1); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector left = input.vector(0); + DuckDBReadableVector right = input.vector(1); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); + output.setInt(row, function.applyAsInt(left.getInt(row), right.getInt(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withIntFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withIntFunction at row " + row, exception); } } }; } static DuckDBScalarFunction doubleUnary(DoubleUnaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); - } - DuckDBReadableVector in = ctx.input(0); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector in = input.vector(0); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (in.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setDouble(row, function.applyAsDouble(in.getDouble(row))); + output.setDouble(row, function.applyAsDouble(in.getDouble(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withDoubleFunction at row " + row, exception); } } }; } static DuckDBScalarFunction doubleBinary(DoubleBinaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withDoubleFunction requires propagateNulls(true)"); - } - DuckDBReadableVector left = ctx.input(0); - DuckDBReadableVector right = ctx.input(1); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector left = input.vector(0); + DuckDBReadableVector right = input.vector(1); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); + output.setDouble(row, function.applyAsDouble(left.getDouble(row), right.getDouble(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withDoubleFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withDoubleFunction at row " + row, exception); } } }; } static DuckDBScalarFunction longUnary(LongUnaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); - } - DuckDBReadableVector in = ctx.input(0); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector in = input.vector(0); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (in.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setLong(row, function.applyAsLong(in.getLong(row))); + output.setLong(row, function.applyAsLong(in.getLong(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withLongFunction at row " + row, exception); } } }; } static DuckDBScalarFunction longBinary(LongBinaryOperator function) { - return ctx -> { - if (!ctx.nullsPropagated()) { - throw new DuckDBFunctionException("withLongFunction requires propagateNulls(true)"); - } - DuckDBReadableVector left = ctx.input(0); - DuckDBReadableVector right = ctx.input(1); - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + DuckDBReadableVector left = input.vector(0); + DuckDBReadableVector right = input.vector(1); + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { if (left.isNull(row) || right.isNull(row)) { - out.setNull(row); + output.setNull(row); } else { - out.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); + output.setLong(row, function.applyAsLong(left.getLong(row), right.getLong(row))); } - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute withLongFunction at row " + row, exception); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute withLongFunction at row " + row, exception); } } }; @@ -313,16 +274,14 @@ static DuckDBScalarFunction nullary(Supplier function, DuckDBColumnType retur throws SQLException { @SuppressWarnings("unchecked") Supplier typedFunction = (Supplier) function; TypeCodec outCodec = codecFor(returnType, returnJavaType); - return ctx -> { - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); + return (input, output) -> { + long rowCount = input.rowCount(); for (long row = 0; row < rowCount; row++) { try { Object result = typedFunction.get(); - outCodec.write(out, row, result); - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute supplier scalar function at row " + row, - exception); + outCodec.write(output, row, result); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute supplier scalar function at row " + row, exception); } } }; @@ -338,42 +297,30 @@ static DuckDBScalarFunction variadic(Function function, DuckDBColum for (int i = 0; i < fixedTypes.length; i++) { fixedCodecs[i] = codecFor(fixedTypes[i], fixedJavaTypes[i]); } - return ctx -> { - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); - int vectorCount = Math.toIntExact(ctx.columnCount()); - boolean propagateNulls = ctx.nullsPropagated(); + return (input, output) -> { + long rowCount = input.rowCount(); + int vectorCount = Math.toIntExact(input.columnCount()); DuckDBReadableVector[] vectors = new DuckDBReadableVector[vectorCount]; TypeCodec[] codecs = new TypeCodec[vectorCount]; for (int column = 0; column < vectorCount; column++) { - vectors[column] = ctx.input(column); + vectors[column] = input.vector(column); codecs[column] = column < fixedCodecs.length ? fixedCodecs[column] : varArgCodec; } Object[] args = new Object[vectorCount]; for (long row = 0; row < rowCount; row++) { try { - boolean skipRow = false; for (int column = 0; column < vectorCount; column++) { DuckDBReadableVector vector = vectors[column]; if (vector.isNull(row)) { - if (propagateNulls) { - out.setNull(row); - skipRow = true; - break; - } args[column] = null; } else { args[column] = codecs[column].read(vector, row); } } - if (skipRow) { - continue; - } Object result = function.apply(args); - outCodec.write(out, row, result); - } catch (DuckDBFunctionException exception) { - throw new DuckDBFunctionException("Failed to execute variadic scalar function at row " + row, - exception); + outCodec.write(output, row, result); + } catch (FunctionException exception) { + throw new FunctionException("Failed to execute variadic scalar function at row " + row, exception); } } }; diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java index 6298d0713..cf3d16c29 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionBuilder.java @@ -18,6 +18,7 @@ import java.util.function.LongBinaryOperator; import java.util.function.LongUnaryOperator; import java.util.function.Supplier; +import org.duckdb.DuckDBFunctions.RegisteredFunction; public final class DuckDBScalarFunctionBuilder implements AutoCloseable { private ByteBuffer scalarFunctionRef; @@ -31,7 +32,7 @@ public final class DuckDBScalarFunctionBuilder implements AutoCloseable { private final List parameterColumnTypes = new ArrayList<>(); private final List> parameterJavaTypes = new ArrayList<>(); private boolean volatileFlag; - private boolean specialHandlingFlag; + private boolean nullInNullOutFlag; private boolean propagateNullsFlag; private boolean finalized; @@ -76,6 +77,17 @@ public DuckDBScalarFunctionBuilder withParameter(DuckDBLogicalType parameterType return this; } + public DuckDBScalarFunctionBuilder withParameters(DuckDBLogicalType... parameterTypes) throws SQLException { + ensureNotFinalized(); + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + for (DuckDBLogicalType parameterType : parameterTypes) { + withParameter(parameterType); + } + return this; + } + public DuckDBScalarFunctionBuilder withReturnType(Class returnType) throws SQLException { ensureNotFinalized(); if (returnType == null) { @@ -121,6 +133,17 @@ public DuckDBScalarFunctionBuilder withParameter(DuckDBColumnType parameterType) return addMappedParameterType(parameterType, null); } + public DuckDBScalarFunctionBuilder withParameters(DuckDBColumnType... parameterTypes) throws SQLException { + ensureNotFinalized(); + if (parameterTypes == null) { + throw new SQLException("Parameter types cannot be null"); + } + for (DuckDBColumnType parameterType : parameterTypes) { + withParameter(parameterType); + } + return this; + } + public DuckDBScalarFunctionBuilder withVectorizedFunction(DuckDBScalarFunction function) throws SQLException { ensureNotFinalized(); if (function == null) { @@ -288,14 +311,13 @@ public DuckDBScalarFunctionBuilder withVolatile() throws SQLException { return this; } - public DuckDBScalarFunctionBuilder withSpecialHandling() throws SQLException { + public DuckDBScalarFunctionBuilder withNullInNullOut() throws SQLException { ensureNotFinalized(); - this.specialHandlingFlag = true; - duckdb_scalar_function_set_special_handling(scalarFunctionRef); + this.nullInNullOutFlag = true; return this; } - public DuckDBRegisteredFunction register(Connection connection) throws SQLException { + public RegisteredFunction register(Connection connection) throws SQLException { ensureNotFinalized(); if (connection == null) { throw new SQLException("Connection cannot be null"); @@ -309,6 +331,9 @@ public DuckDBRegisteredFunction register(Connection connection) throws SQLExcept if (callback == null) { throw new SQLException("Scalar function callback must be defined"); } + if (!nullInNullOutFlag) { + duckdb_scalar_function_set_special_handling(scalarFunctionRef); + } DuckDBConnection duckConnection = unwrapConnection(connection); Lock connectionLock = duckConnection.connRefLock; connectionLock.lock(); @@ -318,9 +343,9 @@ public DuckDBRegisteredFunction register(Connection connection) throws SQLExcept if (status != 0) { throw new SQLException("Failed to register scalar function '" + functionName + "'"); } - DuckDBRegisteredFunction registeredFunction = DuckDBRegisteredFunction.of( + RegisteredFunction registeredFunction = DuckDBFunctions.createRegisteredFunction( functionName, parameterTypes, parameterColumnTypes, returnType, returnColumnType, callback, varArgType, - volatileFlag, specialHandlingFlag, propagateNullsFlag); + volatileFlag, nullInNullOutFlag, propagateNullsFlag); DuckDBDriver.registerFunction(registeredFunction); return registeredFunction; } finally { @@ -409,8 +434,7 @@ private DuckDBScalarFunctionBuilder setCallback(DuckDBScalarFunction function, b throws SQLException { this.callback = function; this.propagateNullsFlag = requiresNullPropagation; - duckdb_scalar_function_set_function(scalarFunctionRef, - new DuckDBScalarFunctionWrapper(function, propagateNullsFlag)); + duckdb_scalar_function_set_function(scalarFunctionRef, new DuckDBScalarFunctionWrapper(function)); return this; } @@ -421,6 +445,7 @@ private void ensurePrimitiveCallbackCompatible(String callbackMethodName) throws } private void enablePrimitiveNullPropagation() { + nullInNullOutFlag = false; propagateNullsFlag = true; } diff --git a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java index 1bfc7f2d1..9390bd135 100644 --- a/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java +++ b/src/main/java/org/duckdb/DuckDBScalarFunctionWrapper.java @@ -6,19 +6,16 @@ final class DuckDBScalarFunctionWrapper { private final DuckDBScalarFunction function; - private final boolean propagateNulls; - DuckDBScalarFunctionWrapper(DuckDBScalarFunction function, boolean propagateNulls) { + DuckDBScalarFunctionWrapper(DuckDBScalarFunction function) { this.function = function; - this.propagateNulls = propagateNulls; } public void execute(ByteBuffer functionInfo, ByteBuffer inputChunk, ByteBuffer outputVector) { try { DuckDBDataChunkReader inputReader = new DuckDBDataChunkReader(inputChunk); - DuckDBWritableVector outputWriter = new DuckDBWritableVectorImpl(outputVector, inputReader.rowCount()); - DuckDBScalarContext context = new DuckDBScalarContext(inputReader, outputWriter, propagateNulls); - function.apply(context); + DuckDBWritableVector outputWriter = new DuckDBWritableVector(outputVector, inputReader.rowCount()); + function.apply(inputReader, outputWriter); } catch (Throwable throwable) { reportError(functionInfo, throwable); } diff --git a/src/main/java/org/duckdb/DuckDBScalarRow.java b/src/main/java/org/duckdb/DuckDBScalarRow.java deleted file mode 100644 index 7b51cf29e..000000000 --- a/src/main/java/org/duckdb/DuckDBScalarRow.java +++ /dev/null @@ -1,471 +0,0 @@ -package org.duckdb; - -import java.math.BigDecimal; -import java.math.BigInteger; -import java.sql.Timestamp; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.OffsetDateTime; - -public final class DuckDBScalarRow { - private final DuckDBScalarContext context; - private final long rowIndex; - - DuckDBScalarRow(DuckDBScalarContext context, long rowIndex) { - this.context = context; - this.rowIndex = rowIndex; - } - - public long index() { - return rowIndex; - } - - public boolean isNull(int columnIndex) { - return input(columnIndex).isNull(rowIndex); - } - - public boolean getBoolean(int columnIndex) { - try { - return input(columnIndex).getBoolean(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("BOOLEAN", columnIndex, exception); - } - } - - public boolean getBoolean(int columnIndex, boolean defaultVal) { - try { - return input(columnIndex).getBoolean(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("BOOLEAN", columnIndex, exception); - } - } - - public byte getByte(int columnIndex) { - try { - return input(columnIndex).getByte(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("TINYINT", columnIndex, exception); - } - } - - public byte getByte(int columnIndex, byte defaultVal) { - try { - return input(columnIndex).getByte(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("TINYINT", columnIndex, exception); - } - } - - public short getShort(int columnIndex) { - try { - return input(columnIndex).getShort(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("SMALLINT", columnIndex, exception); - } - } - - public short getShort(int columnIndex, short defaultVal) { - try { - return input(columnIndex).getShort(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("SMALLINT", columnIndex, exception); - } - } - - public short getUint8(int columnIndex) { - try { - return input(columnIndex).getUint8(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("UTINYINT", columnIndex, exception); - } - } - - public short getUint8(int columnIndex, short defaultVal) { - try { - return input(columnIndex).getUint8(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("UTINYINT", columnIndex, exception); - } - } - - public int getUint16(int columnIndex) { - try { - return input(columnIndex).getUint16(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("USMALLINT", columnIndex, exception); - } - } - - public int getUint16(int columnIndex, int defaultVal) { - try { - return input(columnIndex).getUint16(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("USMALLINT", columnIndex, exception); - } - } - - public int getInt(int columnIndex) { - try { - return input(columnIndex).getInt(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("INTEGER", columnIndex, exception); - } - } - - public int getInt(int columnIndex, int defaultVal) { - try { - return input(columnIndex).getInt(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("INTEGER", columnIndex, exception); - } - } - - public long getUint32(int columnIndex) { - try { - return input(columnIndex).getUint32(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("UINTEGER", columnIndex, exception); - } - } - - public long getUint32(int columnIndex, long defaultVal) { - try { - return input(columnIndex).getUint32(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("UINTEGER", columnIndex, exception); - } - } - - public long getLong(int columnIndex) { - try { - return input(columnIndex).getLong(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("BIGINT", columnIndex, exception); - } - } - - public long getLong(int columnIndex, long defaultVal) { - try { - return input(columnIndex).getLong(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("BIGINT", columnIndex, exception); - } - } - - public BigInteger getHugeInt(int columnIndex) { - try { - return input(columnIndex).getHugeInt(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("HUGEINT", columnIndex, exception); - } - } - - public BigInteger getUHugeInt(int columnIndex) { - try { - return input(columnIndex).getUHugeInt(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("UHUGEINT", columnIndex, exception); - } - } - - public BigInteger getUint64(int columnIndex) { - try { - return input(columnIndex).getUint64(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("UBIGINT", columnIndex, exception); - } - } - - public float getFloat(int columnIndex) { - try { - return input(columnIndex).getFloat(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("FLOAT", columnIndex, exception); - } - } - - public float getFloat(int columnIndex, float defaultVal) { - try { - return input(columnIndex).getFloat(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("FLOAT", columnIndex, exception); - } - } - - public double getDouble(int columnIndex) { - try { - return input(columnIndex).getDouble(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("DOUBLE", columnIndex, exception); - } - } - - public double getDouble(int columnIndex, double defaultVal) { - try { - return input(columnIndex).getDouble(rowIndex, defaultVal); - } catch (DuckDBFunctionException exception) { - throw readFailure("DOUBLE", columnIndex, exception); - } - } - - public LocalDate getLocalDate(int columnIndex) { - try { - return input(columnIndex).getLocalDate(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("DATE", columnIndex, exception); - } - } - - public java.sql.Date getDate(int columnIndex) { - try { - return input(columnIndex).getDate(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("DATE", columnIndex, exception); - } - } - - public LocalDateTime getLocalDateTime(int columnIndex) { - try { - return input(columnIndex).getLocalDateTime(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("TIMESTAMP", columnIndex, exception); - } - } - - public Timestamp getTimestamp(int columnIndex) { - try { - return input(columnIndex).getTimestamp(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("TIMESTAMP", columnIndex, exception); - } - } - - public OffsetDateTime getOffsetDateTime(int columnIndex) { - try { - return input(columnIndex).getOffsetDateTime(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("TIMESTAMP WITH TIME ZONE", columnIndex, exception); - } - } - - public BigDecimal getBigDecimal(int columnIndex) { - try { - return input(columnIndex).getBigDecimal(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("DECIMAL", columnIndex, exception); - } - } - - public String getString(int columnIndex) { - try { - return input(columnIndex).getString(rowIndex); - } catch (DuckDBFunctionException exception) { - throw readFailure("VARCHAR", columnIndex, exception); - } - } - - public void setNull() { - try { - context.output().setNull(rowIndex); - } catch (DuckDBFunctionException exception) { - throw writeFailure("NULL", exception); - } - } - - public void setBoolean(boolean value) { - try { - context.output().setBoolean(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("BOOLEAN", exception); - } - } - - public void setByte(byte value) { - try { - context.output().setByte(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TINYINT", exception); - } - } - - public void setShort(short value) { - try { - context.output().setShort(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("SMALLINT", exception); - } - } - - public void setUint8(int value) { - try { - context.output().setUint8(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("UTINYINT", exception); - } - } - - public void setUint16(int value) { - try { - context.output().setUint16(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("USMALLINT", exception); - } - } - - public void setInt(int value) { - try { - context.output().setInt(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("INTEGER", exception); - } - } - - public void setUint32(long value) { - try { - context.output().setUint32(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("UINTEGER", exception); - } - } - - public void setLong(long value) { - try { - context.output().setLong(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("BIGINT", exception); - } - } - - public void setHugeInt(BigInteger value) { - try { - context.output().setHugeInt(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("HUGEINT", exception); - } - } - - public void setUHugeInt(BigInteger value) { - try { - context.output().setUHugeInt(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("UHUGEINT", exception); - } - } - - public void setUint64(BigInteger value) { - try { - context.output().setUint64(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("UBIGINT", exception); - } - } - - public void setFloat(float value) { - try { - context.output().setFloat(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("FLOAT", exception); - } - } - - public void setDouble(double value) { - try { - context.output().setDouble(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("DOUBLE", exception); - } - } - - public void setDate(LocalDate value) { - try { - context.output().setDate(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("DATE", exception); - } - } - - public void setDate(java.sql.Date value) { - try { - context.output().setDate(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("DATE", exception); - } - } - - public void setDate(java.util.Date value) { - try { - context.output().setDate(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("DATE", exception); - } - } - - public void setTimestamp(LocalDateTime value) { - try { - context.output().setTimestamp(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TIMESTAMP", exception); - } - } - - public void setTimestamp(Timestamp value) { - try { - context.output().setTimestamp(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TIMESTAMP", exception); - } - } - - public void setTimestamp(java.util.Date value) { - try { - context.output().setTimestamp(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TIMESTAMP", exception); - } - } - - public void setTimestamp(LocalDate value) { - try { - context.output().setTimestamp(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TIMESTAMP", exception); - } - } - - public void setOffsetDateTime(OffsetDateTime value) { - try { - context.output().setOffsetDateTime(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("TIMESTAMP WITH TIME ZONE", exception); - } - } - - public void setBigDecimal(BigDecimal value) { - try { - context.output().setBigDecimal(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("DECIMAL", exception); - } - } - - public void setString(String value) { - try { - context.output().setString(rowIndex, value); - } catch (DuckDBFunctionException exception) { - throw writeFailure("VARCHAR", exception); - } - } - - private DuckDBReadableVector input(int columnIndex) { - return context.inputUnchecked(columnIndex); - } - - private DuckDBFunctionException readFailure(String type, int columnIndex, DuckDBFunctionException exception) { - return new DuckDBFunctionException( - "Failed to read " + type + " from input column " + columnIndex + " at row " + rowIndex, exception); - } - - private DuckDBFunctionException writeFailure(String type, DuckDBFunctionException exception) { - return new DuckDBFunctionException("Failed to write " + type + " to output row " + rowIndex, exception); - } -} diff --git a/src/main/java/org/duckdb/DuckDBWritableVector.java b/src/main/java/org/duckdb/DuckDBWritableVector.java index 1179e9084..4b46d7450 100644 --- a/src/main/java/org/duckdb/DuckDBWritableVector.java +++ b/src/main/java/org/duckdb/DuckDBWritableVector.java @@ -1,116 +1,590 @@ package org.duckdb; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; + import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.sql.Timestamp; +import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; - -/** - * Mutable scalar callback view over a DuckDB output vector. - * - *

Implementations throw {@link DuckDBFunctionException} for callback-time type/value write - * errors. Invalid row indexes throw {@link IndexOutOfBoundsException}. - */ -public abstract class DuckDBWritableVector { - public abstract DuckDBColumnType getType(); - - public abstract long rowCount(); - - public abstract void addNull(); - - public abstract void setNull(long row); - - public abstract void addBoolean(boolean value); - - public abstract void setBoolean(long row, boolean value); - - public abstract void addByte(byte value); - - public abstract void setByte(long row, byte value); - - public abstract void addShort(short value); - - public abstract void setShort(long row, short value); - - public abstract void addUint8(int value); - - public abstract void setUint8(long row, int value); - - public abstract void addUint16(int value); - - public abstract void setUint16(long row, int value); - - public abstract void addInt(int value); - - public abstract void setInt(long row, int value); - - public abstract void addUint32(long value); - - public abstract void setUint32(long row, long value); - - public abstract void addLong(long value); - - public abstract void setLong(long row, long value); - - public abstract void addHugeInt(BigInteger value); - - public abstract void setHugeInt(long row, BigInteger value); - - public abstract void addUHugeInt(BigInteger value); - - public abstract void setUHugeInt(long row, BigInteger value); - - public abstract void addUint64(BigInteger value); - - public abstract void setUint64(long row, BigInteger value); - - public abstract void addFloat(float value); - - public abstract void setFloat(long row, float value); - - public abstract void addDouble(double value); - - public abstract void setDouble(long row, double value); - - public abstract void addDate(LocalDate value); - - public abstract void setDate(long row, LocalDate value); - - public abstract void addDate(java.sql.Date value); - - public abstract void setDate(long row, java.sql.Date value); - - public abstract void addDate(java.util.Date value); - - public abstract void setDate(long row, java.util.Date value); - - public abstract void addTimestamp(LocalDateTime value); - - public abstract void setTimestamp(long row, LocalDateTime value); - - public abstract void addTimestamp(Timestamp value); - - public abstract void setTimestamp(long row, Timestamp value); - - public abstract void addTimestamp(java.util.Date value); - - public abstract void setTimestamp(long row, java.util.Date value); - - public abstract void addTimestamp(LocalDate value); - - public abstract void setTimestamp(long row, LocalDate value); - - public abstract void addOffsetDateTime(OffsetDateTime value); - - public abstract void setOffsetDateTime(long row, OffsetDateTime value); - - public abstract void addBigDecimal(BigDecimal value); - - public abstract void setBigDecimal(long row, BigDecimal value); - - public abstract void addString(String value); - - public abstract void setString(long row, String value); +import java.time.ZoneId; +import java.time.ZoneOffset; +import org.duckdb.DuckDBFunctions.FunctionException; + +public final class DuckDBWritableVector { + private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + + private final ByteBuffer vectorRef; + private final long rowCount; + private final DuckDBVectorTypeInfo typeInfo; + private final ByteBuffer data; + private final ByteBuffer validity; + + DuckDBWritableVector(ByteBuffer vectorRef, long rowCount) { + if (vectorRef == null) { + throw new FunctionException("Invalid vector reference"); + } + this.vectorRef = vectorRef; + this.rowCount = rowCount; + try { + this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); + } catch (java.sql.SQLException exception) { + throw new FunctionException("Failed to resolve vector type info", exception); + } + this.data = + duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); + duckdb_vector_ensure_validity_writable(vectorRef); + this.validity = duckdb_vector_get_validity(vectorRef, rowCount); + this.validity.order(NATIVE_ORDER); + } + + public DuckDBColumnType getType() { + return typeInfo.columnType; + } + + public long rowCount() { + return rowCount; + } + + public void setNull(long row) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + setRowValidity(row, false); + } + + public void setBoolean(long row, boolean value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BOOLEAN); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.put(checkedRowIndex(row), value ? (byte) 1 : (byte) 0); + markValid(row); + } + + public void setByte(long row, byte value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.TINYINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.put(checkedRowIndex(row), value); + markValid(row); + } + + public void setShort(long row, short value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.SMALLINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.putShort(checkedByteOffset(row, Short.BYTES), value); + markValid(row); + } + + public void setUint8(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UTINYINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UTINYINT", value, 0xFFL); + if (rangeError != null) { + throw new FunctionException(rangeError); + } + data.put(checkedRowIndex(row), (byte) value); + markValid(row); + } + + public void setUint16(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.USMALLINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("USMALLINT", value, 0xFFFFL); + if (rangeError != null) { + throw new FunctionException(rangeError); + } + data.putShort(checkedByteOffset(row, Short.BYTES), (short) value); + markValid(row); + } + + public void setInt(long row, int value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.INTEGER); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), value); + markValid(row); + } + + public void setUint32(long row, long value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UINTEGER); + if (typeError != null) { + throw new FunctionException(typeError); + } + String rangeError = unsignedRangeErrorMessage("UINTEGER", value, 0xFFFFFFFFL); + if (rangeError != null) { + throw new FunctionException(rangeError); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) value); + markValid(row); + } + + public void setLong(long row, long value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.BIGINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.putLong(checkedByteOffset(row, Long.BYTES), value); + markValid(row); + } + + public void setHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.HUGEINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + DuckDBHugeInt hugeInt; + try { + hugeInt = new DuckDBHugeInt(value); + } catch (java.sql.SQLException exception) { + throw new FunctionException("Value out of range for HUGEINT: " + value, exception); + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, hugeInt.lower()); + data.putLong(offset + Long.BYTES, hugeInt.upper()); + markValid(row); + } + + public void setUHugeInt(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UHUGEINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(DuckDBHugeInt.UHUGE_INT_MAX) > 0) { + throw new FunctionException("Value out of range for UHUGEINT: " + value); + } + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, value.longValue()); + data.putLong(offset + Long.BYTES, value.shiftRight(Long.SIZE).longValue()); + markValid(row); + } + + public void setUint64(long row, BigInteger value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.UBIGINT); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { + throw new FunctionException("Value out of range for UBIGINT: " + value); + } + data.putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); + markValid(row); + } + + public void setFloat(long row, float value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.FLOAT); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.putFloat(checkedByteOffset(row, Float.BYTES), value); + markValid(row); + } + + public void setDouble(long row, double value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DOUBLE); + if (typeError != null) { + throw new FunctionException(typeError); + } + data.putDouble(checkedByteOffset(row, Double.BYTES), value); + markValid(row); + } + + public void setDate(long row, LocalDate value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DATE); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + long days = value.toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new FunctionException("Value out of range for DATE: " + value); + } + data.putInt(checkedByteOffset(row, Integer.BYTES), (int) days); + markValid(row); + } + + public void setDate(long row, java.sql.Date value) { + setDate(row, value == null ? null : value.toLocalDate()); + } + + public void setDate(long row, java.util.Date value) { + if (value == null) { + setNull(row); + return; + } + if (value instanceof java.sql.Date) { + setDate(row, (java.sql.Date) value); + return; + } + LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate(); + setDate(row, localDate); + } + + public void setTimestamp(long row, LocalDateTime value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); + markValid(row); + } + + public void setTimestamp(long row, Timestamp value) { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(value.toInstant())); + markValid(row); + return; + } + setTimestamp(row, value.toLocalDateTime()); + } + + public void setTimestamp(long row, java.util.Date value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(false); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + if (value instanceof Timestamp) { + setTimestamp(row, (Timestamp) value); + return; + } + data.putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); + markValid(row); + } + + public void setTimestamp(long row, LocalDate value) { + if (value == null) { + setNull(row); + return; + } + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); + data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); + markValid(row); + return; + } + setTimestamp(row, value.atStartOfDay()); + } + + public void setOffsetDateTime(long row, OffsetDateTime value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = timestampTypeMismatchMessage(true); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + data.putLong( + checkedByteOffset(row, Long.BYTES), + DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); + markValid(row); + } + + public void setBigDecimal(long row, BigDecimal value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.DECIMAL); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + BigDecimal scaled; + try { + scaled = value.setScale(typeInfo.decimalMeta.scale); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + if (scaled.precision() > typeInfo.decimalMeta.width) { + throw decimalOutOfRange(value); + } + switch (typeInfo.storageType) { + case DUCKDB_TYPE_SMALLINT: + try { + data.putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_INTEGER: + try { + data.putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_BIGINT: + try { + data.putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); + } catch (ArithmeticException e) { + throw decimalOutOfRange(value, e); + } + break; + case DUCKDB_TYPE_HUGEINT: { + BigInteger unscaled = scaled.unscaledValue(); + int offset = checkedByteOffset(row, typeInfo.widthBytes); + data.putLong(offset, unscaled.longValue()); + data.putLong(offset + Long.BYTES, unscaled.shiftRight(Long.SIZE).longValue()); + break; + } + default: + throw new FunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); + } + markValid(row); + } + + public void setString(long row, String value) { + String rowError = rowIndexErrorMessage(row); + if (rowError != null) { + throw new IndexOutOfBoundsException(rowError); + } + String typeError = typeMismatchMessage(DuckDBColumnType.VARCHAR); + if (typeError != null) { + throw new FunctionException(typeError); + } + if (value == null) { + setNull(row); + return; + } + duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8)); + markValid(row); + } + + ByteBuffer vectorRef() { + return vectorRef; + } + + private void markValid(long row) { + setRowValidity(row, true); + } + + private void setRowValidity(long row, boolean valid) { + int entryOffset = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); + long bitIndex = row % Long.SIZE; + long mask = 1L << bitIndex; + long entry = validity.getLong(entryOffset); + if (valid) { + entry |= mask; + } else { + entry &= ~mask; + } + validity.putLong(entryOffset, entry); + } + + private String typeMismatchMessage(DuckDBColumnType expected) { + if (typeInfo.columnType != expected) { + return "Expected vector type " + expected + ", found " + typeInfo.columnType; + } + return null; + } + + private String rowIndexErrorMessage(long row) { + if (row < 0 || row >= rowCount) { + return "Row index out of bounds: " + row; + } + return null; + } + + private String timestampTypeMismatchMessage(boolean requireTimezone) { + if (requireTimezone) { + if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + return "Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType; + } + return null; + } + switch (typeInfo.columnType) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return null; + default: + return "Expected vector type TIMESTAMP*, found " + typeInfo.columnType; + } + } + + private long encodeLocalDateTime(LocalDateTime value) { + Instant instant; + if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { + instant = value.atZone(ZoneId.systemDefault()).toInstant(); + } else { + instant = value.toInstant(ZoneOffset.UTC); + } + return encodeInstant(instant); + } + + private long encodeJavaUtilDate(java.util.Date value) { + return encodeInstant(Instant.ofEpochMilli(value.getTime())); + } + + private long encodeInstant(Instant instant) { + long epochSeconds = instant.getEpochSecond(); + int nano = instant.getNano(); + switch (typeInfo.capiType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return epochSeconds; + case DUCKDB_TYPE_TIMESTAMP_MS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L); + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_TZ: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L); + case DUCKDB_TYPE_TIMESTAMP_NS: + return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); + default: + throw new FunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); + } + } + + private static String unsignedRangeErrorMessage(String typeName, long value, long maxValue) { + if (value < 0 || value > maxValue) { + return "Value out of range for " + typeName + ": " + value; + } + return null; + } + + private FunctionException decimalOutOfRange(BigDecimal value) { + return new FunctionException("Value out of range for " + decimalTypeName() + ": " + value); + } + + private FunctionException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { + FunctionException exception = decimalOutOfRange(value); + exception.initCause(cause); + return exception; + } + + private String decimalTypeName() { + return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")"; + } + + private int checkedRowIndex(long row) { + return Math.toIntExact(row); + } + + private int checkedByteOffset(long row, int elementWidth) { + return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); + } } diff --git a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java b/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java deleted file mode 100644 index 33887abb7..000000000 --- a/src/main/java/org/duckdb/DuckDBWritableVectorImpl.java +++ /dev/null @@ -1,765 +0,0 @@ -package org.duckdb; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.duckdb.DuckDBBindings.*; - -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.sql.Timestamp; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.OffsetDateTime; -import java.time.ZoneId; -import java.time.ZoneOffset; - -final class DuckDBWritableVectorImpl extends DuckDBWritableVector { - private static final BigInteger UINT64_MAX = new BigInteger("18446744073709551615"); - private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); - - private final ByteBuffer vectorRef; - private final long rowCount; - private final DuckDBVectorTypeInfo typeInfo; - private final ByteBuffer data; - private ByteBuffer validity; - private long appendIndex; - - DuckDBWritableVectorImpl(ByteBuffer vectorRef, long rowCount) { - if (vectorRef == null) { - throw new DuckDBFunctionException("Invalid vector reference"); - } - this.vectorRef = vectorRef; - this.rowCount = rowCount; - try { - this.typeInfo = DuckDBVectorTypeInfo.fromVector(vectorRef); - } catch (java.sql.SQLException exception) { - throw new DuckDBFunctionException("Failed to resolve vector type info", exception); - } - this.data = - duckdb_vector_get_data(vectorRef, Math.multiplyExact(rowCount, typeInfo.widthBytes)).order(NATIVE_ORDER); - ByteBuffer validityBuffer = duckdb_vector_get_validity(vectorRef, rowCount); - this.validity = validityBuffer == null ? null : validityBuffer.order(NATIVE_ORDER); - } - - @Override - public DuckDBColumnType getType() { - return typeInfo.columnType; - } - - @Override - public long rowCount() { - return rowCount; - } - - @Override - public void addNull() { - setNull(nextAppendRow()); - } - - @Override - public void setNull(long row) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - ensureValidity(); - setRowValidity(row, false); - advanceAppendIndex(row); - } - - @Override - public void addBoolean(boolean value) { - setBoolean(nextAppendRow(), value); - } - - @Override - public void setBoolean(long row, boolean value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.BOOLEAN); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.put(checkedRowIndex(row), value ? (byte) 1 : (byte) 0); - markValid(row); - } - - @Override - public void addByte(byte value) { - setByte(nextAppendRow(), value); - } - - @Override - public void setByte(long row, byte value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.TINYINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.put(checkedRowIndex(row), value); - markValid(row); - } - - @Override - public void addShort(short value) { - setShort(nextAppendRow(), value); - } - - @Override - public void setShort(long row, short value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.SMALLINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.putShort(checkedByteOffset(row, Short.BYTES), value); - markValid(row); - } - - @Override - public void addUint8(int value) { - setUint8(nextAppendRow(), value); - } - - @Override - public void setUint8(long row, int value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.UTINYINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - String rangeError = unsignedRangeErrorMessage("UTINYINT", value, 0xFFL); - if (rangeError != null) { - throw new DuckDBFunctionException(rangeError); - } - data.put(checkedRowIndex(row), (byte) value); - markValid(row); - } - - @Override - public void addUint16(int value) { - setUint16(nextAppendRow(), value); - } - - @Override - public void setUint16(long row, int value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.USMALLINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - String rangeError = unsignedRangeErrorMessage("USMALLINT", value, 0xFFFFL); - if (rangeError != null) { - throw new DuckDBFunctionException(rangeError); - } - data.putShort(checkedByteOffset(row, Short.BYTES), (short) value); - markValid(row); - } - - @Override - public void addInt(int value) { - setInt(nextAppendRow(), value); - } - - @Override - public void setInt(long row, int value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.INTEGER); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.putInt(checkedByteOffset(row, Integer.BYTES), value); - markValid(row); - } - - @Override - public void addUint32(long value) { - setUint32(nextAppendRow(), value); - } - - @Override - public void setUint32(long row, long value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.UINTEGER); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - String rangeError = unsignedRangeErrorMessage("UINTEGER", value, 0xFFFFFFFFL); - if (rangeError != null) { - throw new DuckDBFunctionException(rangeError); - } - data.putInt(checkedByteOffset(row, Integer.BYTES), (int) value); - markValid(row); - } - - @Override - public void addLong(long value) { - setLong(nextAppendRow(), value); - } - - @Override - public void setLong(long row, long value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.BIGINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.putLong(checkedByteOffset(row, Long.BYTES), value); - markValid(row); - } - - @Override - public void addHugeInt(BigInteger value) { - setHugeInt(nextAppendRow(), value); - } - - @Override - public void setHugeInt(long row, BigInteger value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.HUGEINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - DuckDBHugeInt hugeInt; - try { - hugeInt = new DuckDBHugeInt(value); - } catch (java.sql.SQLException exception) { - throw new DuckDBFunctionException("Value out of range for HUGEINT: " + value, exception); - } - int offset = checkedByteOffset(row, typeInfo.widthBytes); - data.putLong(offset, hugeInt.lower()); - data.putLong(offset + Long.BYTES, hugeInt.upper()); - markValid(row); - } - - @Override - public void addUHugeInt(BigInteger value) { - setUHugeInt(nextAppendRow(), value); - } - - @Override - public void setUHugeInt(long row, BigInteger value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.UHUGEINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - if (value.signum() < 0 || value.compareTo(DuckDBHugeInt.UHUGE_INT_MAX) > 0) { - throw new DuckDBFunctionException("Value out of range for UHUGEINT: " + value); - } - int offset = checkedByteOffset(row, typeInfo.widthBytes); - data.putLong(offset, value.longValue()); - data.putLong(offset + Long.BYTES, value.shiftRight(Long.SIZE).longValue()); - markValid(row); - } - - @Override - public void addUint64(BigInteger value) { - setUint64(nextAppendRow(), value); - } - - @Override - public void setUint64(long row, BigInteger value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.UBIGINT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - if (value.signum() < 0 || value.compareTo(UINT64_MAX) > 0) { - throw new DuckDBFunctionException("Value out of range for UBIGINT: " + value); - } - data.putLong(checkedByteOffset(row, Long.BYTES), value.longValue()); - markValid(row); - } - - @Override - public void addFloat(float value) { - setFloat(nextAppendRow(), value); - } - - @Override - public void setFloat(long row, float value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.FLOAT); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.putFloat(checkedByteOffset(row, Float.BYTES), value); - markValid(row); - } - - @Override - public void addDouble(double value) { - setDouble(nextAppendRow(), value); - } - - @Override - public void setDouble(long row, double value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.DOUBLE); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - data.putDouble(checkedByteOffset(row, Double.BYTES), value); - markValid(row); - } - - @Override - public void addDate(LocalDate value) { - setDate(nextAppendRow(), value); - } - - @Override - public void setDate(long row, LocalDate value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.DATE); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - long days = value.toEpochDay(); - if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { - throw new DuckDBFunctionException("Value out of range for DATE: " + value); - } - data.putInt(checkedByteOffset(row, Integer.BYTES), (int) days); - markValid(row); - } - - @Override - public void addDate(java.sql.Date value) { - setDate(nextAppendRow(), value); - } - - @Override - public void setDate(long row, java.sql.Date value) { - setDate(row, value == null ? null : value.toLocalDate()); - } - - @Override - public void addDate(java.util.Date value) { - setDate(nextAppendRow(), value); - } - - @Override - public void setDate(long row, java.util.Date value) { - if (value == null) { - setNull(row); - return; - } - if (value instanceof java.sql.Date) { - setDate(row, (java.sql.Date) value); - return; - } - LocalDate localDate = Instant.ofEpochMilli(value.getTime()).atZone(ZoneOffset.UTC).toLocalDate(); - setDate(row, localDate); - } - - @Override - public void addTimestamp(LocalDateTime value) { - setTimestamp(nextAppendRow(), value); - } - - @Override - public void setTimestamp(long row, LocalDateTime value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = timestampTypeMismatchMessage(false); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - data.putLong(checkedByteOffset(row, Long.BYTES), encodeLocalDateTime(value)); - markValid(row); - } - - @Override - public void addTimestamp(Timestamp value) { - setTimestamp(nextAppendRow(), value); - } - - @Override - public void setTimestamp(long row, Timestamp value) { - if (value == null) { - setNull(row); - return; - } - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(value.toInstant())); - markValid(row); - return; - } - setTimestamp(row, value.toLocalDateTime()); - } - - @Override - public void addTimestamp(java.util.Date value) { - setTimestamp(nextAppendRow(), value); - } - - @Override - public void setTimestamp(long row, java.util.Date value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = timestampTypeMismatchMessage(false); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - if (value instanceof Timestamp) { - setTimestamp(row, (Timestamp) value); - return; - } - data.putLong(checkedByteOffset(row, Long.BYTES), encodeJavaUtilDate(value)); - markValid(row); - } - - @Override - public void addTimestamp(LocalDate value) { - setTimestamp(nextAppendRow(), value); - } - - @Override - public void setTimestamp(long row, LocalDate value) { - if (value == null) { - setNull(row); - return; - } - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - Instant instant = value.atStartOfDay(ZoneId.systemDefault()).toInstant(); - data.putLong(checkedByteOffset(row, Long.BYTES), encodeInstant(instant)); - markValid(row); - return; - } - setTimestamp(row, value.atStartOfDay()); - } - - @Override - public void addOffsetDateTime(OffsetDateTime value) { - setOffsetDateTime(nextAppendRow(), value); - } - - @Override - public void setOffsetDateTime(long row, OffsetDateTime value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = timestampTypeMismatchMessage(true); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - data.putLong( - checkedByteOffset(row, Long.BYTES), - DuckDBTimestamp.localDateTime2Micros(value.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime())); - markValid(row); - } - - @Override - public void addBigDecimal(BigDecimal value) { - setBigDecimal(nextAppendRow(), value); - } - - @Override - public void setBigDecimal(long row, BigDecimal value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.DECIMAL); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - BigDecimal scaled; - try { - scaled = value.setScale(typeInfo.decimalMeta.scale); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - if (scaled.precision() > typeInfo.decimalMeta.width) { - throw decimalOutOfRange(value); - } - switch (typeInfo.storageType) { - case DUCKDB_TYPE_SMALLINT: - try { - data.putShort(checkedByteOffset(row, Short.BYTES), scaled.unscaledValue().shortValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_INTEGER: - try { - data.putInt(checkedByteOffset(row, Integer.BYTES), scaled.unscaledValue().intValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_BIGINT: - try { - data.putLong(checkedByteOffset(row, Long.BYTES), scaled.unscaledValue().longValueExact()); - } catch (ArithmeticException e) { - throw decimalOutOfRange(value, e); - } - break; - case DUCKDB_TYPE_HUGEINT: { - BigInteger unscaled = scaled.unscaledValue(); - int offset = checkedByteOffset(row, typeInfo.widthBytes); - data.putLong(offset, unscaled.longValue()); - data.putLong(offset + Long.BYTES, unscaled.shiftRight(Long.SIZE).longValue()); - break; - } - default: - throw new DuckDBFunctionException("Unsupported DECIMAL storage type: " + typeInfo.storageType); - } - markValid(row); - } - - @Override - public void addString(String value) { - setString(nextAppendRow(), value); - } - - @Override - public void setString(long row, String value) { - String rowError = rowIndexErrorMessage(row); - if (rowError != null) { - throw new IndexOutOfBoundsException(rowError); - } - String typeError = typeMismatchMessage(DuckDBColumnType.VARCHAR); - if (typeError != null) { - throw new DuckDBFunctionException(typeError); - } - if (value == null) { - setNull(row); - return; - } - duckdb_vector_assign_string_element_len(vectorRef, row, value.getBytes(UTF_8)); - markValid(row); - } - - ByteBuffer vectorRef() { - return vectorRef; - } - - private void ensureValidity() { - if (validity != null) { - return; - } - duckdb_vector_ensure_validity_writable(vectorRef); - validity = duckdb_vector_get_validity(vectorRef, rowCount); - if (validity == null) { - throw new DuckDBFunctionException("Cannot initialize vector validity"); - } - validity = validity.order(NATIVE_ORDER); - } - - private void markValid(long row) { - if (validity == null) { - advanceAppendIndex(row); - return; - } - setRowValidity(row, true); - advanceAppendIndex(row); - } - - private void setRowValidity(long row, boolean valid) { - int entryOffset = Math.toIntExact(Math.multiplyExact(row / Long.SIZE, (long) Long.BYTES)); - long bitIndex = row % Long.SIZE; - long mask = 1L << bitIndex; - long entry = validity.getLong(entryOffset); - if (valid) { - entry |= mask; - } else { - entry &= ~mask; - } - validity.putLong(entryOffset, entry); - } - - private String typeMismatchMessage(DuckDBColumnType expected) { - if (typeInfo.columnType != expected) { - return "Expected vector type " + expected + ", found " + typeInfo.columnType; - } - return null; - } - - private String rowIndexErrorMessage(long row) { - if (row < 0 || row >= rowCount) { - return "Row index out of bounds: " + row; - } - return null; - } - - private String timestampTypeMismatchMessage(boolean requireTimezone) { - if (requireTimezone) { - if (typeInfo.columnType != DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - return "Expected vector type TIMESTAMP WITH TIME ZONE, found " + typeInfo.columnType; - } - return null; - } - switch (typeInfo.columnType) { - case TIMESTAMP: - case TIMESTAMP_S: - case TIMESTAMP_MS: - case TIMESTAMP_NS: - case TIMESTAMP_WITH_TIME_ZONE: - return null; - default: - return "Expected vector type TIMESTAMP*, found " + typeInfo.columnType; - } - } - - private long encodeLocalDateTime(LocalDateTime value) { - Instant instant; - if (typeInfo.columnType == DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE) { - instant = value.atZone(ZoneId.systemDefault()).toInstant(); - } else { - instant = value.toInstant(ZoneOffset.UTC); - } - return encodeInstant(instant); - } - - private long encodeJavaUtilDate(java.util.Date value) { - return encodeInstant(Instant.ofEpochMilli(value.getTime())); - } - - private long encodeInstant(Instant instant) { - long epochSeconds = instant.getEpochSecond(); - int nano = instant.getNano(); - switch (typeInfo.capiType) { - case DUCKDB_TYPE_TIMESTAMP_S: - return epochSeconds; - case DUCKDB_TYPE_TIMESTAMP_MS: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000L), nano / 1_000_000L); - case DUCKDB_TYPE_TIMESTAMP: - case DUCKDB_TYPE_TIMESTAMP_TZ: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000L), nano / 1_000L); - case DUCKDB_TYPE_TIMESTAMP_NS: - return Math.addExact(Math.multiplyExact(epochSeconds, 1_000_000_000L), nano); - default: - throw new DuckDBFunctionException("Expected vector type TIMESTAMP*, found " + typeInfo.columnType); - } - } - - private static String unsignedRangeErrorMessage(String typeName, long value, long maxValue) { - if (value < 0 || value > maxValue) { - return "Value out of range for " + typeName + ": " + value; - } - return null; - } - - private DuckDBFunctionException decimalOutOfRange(BigDecimal value) { - return new DuckDBFunctionException("Value out of range for " + decimalTypeName() + ": " + value); - } - - private DuckDBFunctionException decimalOutOfRange(BigDecimal value, ArithmeticException cause) { - DuckDBFunctionException exception = decimalOutOfRange(value); - exception.initCause(cause); - return exception; - } - - private String decimalTypeName() { - return "DECIMAL(" + typeInfo.decimalMeta.width + "," + typeInfo.decimalMeta.scale + ")"; - } - - private int checkedRowIndex(long row) { - return Math.toIntExact(row); - } - - private int checkedByteOffset(long row, int elementWidth) { - return Math.toIntExact(Math.multiplyExact(row, (long) elementWidth)); - } - - private void advanceAppendIndex(long row) { - appendIndex = Math.max(appendIndex, Math.addExact(row, 1)); - } - - private long nextAppendRow() { - if (appendIndex >= rowCount) { - throw new IndexOutOfBoundsException("Append index out of bounds: " + appendIndex); - } - return appendIndex; - } -} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 8bdf0ad3f..701d82880 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -9,9 +9,7 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.charset.StandardCharsets; import java.sql.*; -import java.util.Arrays; public class TestBindings { @@ -27,16 +25,16 @@ public static void test_bindings_vector_row_index_stream() throws Exception { ByteBuffer inputVec = duckdb_create_vector(lt); ByteBuffer outputVec = duckdb_create_vector(lt); - DuckDBWritableVector input = new DuckDBWritableVectorImpl(inputVec, 3); + DuckDBWritableVector input = new DuckDBWritableVector(inputVec, 3); input.setInt(0, 1); input.setInt(1, 41); input.setInt(2, -5); - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(inputVec, 3); - DuckDBWritableVector output = new DuckDBWritableVectorImpl(outputVec, 3); - readable.rowIndexStream().forEachOrdered(row -> { output.setInt(row, readable.getInt(row) + 1); }); + DuckDBReadableVector readable = new DuckDBReadableVector(inputVec, 3); + DuckDBWritableVector output = new DuckDBWritableVector(outputVec, 3); + readable.stream().forEachOrdered(row -> { output.setInt(row, readable.getInt(row) + 1); }); - DuckDBReadableVector result = new DuckDBReadableVectorImpl(outputVec, 3); + DuckDBReadableVector result = new DuckDBReadableVector(outputVec, 3); assertEquals(result.getInt(0), 2); assertEquals(result.getInt(1), 42); assertEquals(result.getInt(2), -4); @@ -46,41 +44,6 @@ public static void test_bindings_vector_row_index_stream() throws Exception { duckdb_destroy_logical_type(lt); } - public static void test_bindings_writable_vector_append_after_indexed_write() throws Exception { - ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); - ByteBuffer vec = duckdb_create_vector(lt); - - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 3); - writable.setNull(0); - writable.setInt(1, 41); - writable.addInt(42); - - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 3); - assertTrue(readable.isNull(0)); - assertEquals(readable.getInt(1), 41); - assertEquals(readable.getInt(2), 42); - - duckdb_destroy_vector(vec); - duckdb_destroy_logical_type(lt); - } - - public static void test_bindings_writable_vector_failed_append_does_not_advance() throws Exception { - ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); - ByteBuffer vec = duckdb_create_vector(lt); - - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, 2); - assertThrows(() -> { writable.addString("boom"); }, DuckDBFunctionException.class); - writable.addInt(7); - writable.addInt(8); - - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, 2); - assertEquals(readable.getInt(0), 7); - assertEquals(readable.getInt(1), 8); - - duckdb_destroy_vector(vec); - duckdb_destroy_logical_type(lt); - } - public static void test_bindings_logical_type() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); assertNotNull(lt); @@ -194,11 +157,11 @@ public static void test_bindings_vector_get_string() throws Exception { ByteBuffer vec = duckdb_create_vector(lt); long rowCount = duckdb_vector_size(); - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); writable.setNull(0); writable.setString(1, "duckdb"); - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); assertNull(readable.getString(0)); assertEquals(readable.getString(1), "duckdb"); assertThrows(() -> { readable.getString(rowCount); }, IndexOutOfBoundsException.class); @@ -213,13 +176,13 @@ public static void test_bindings_vector_native_endian_roundtrip() throws Excepti int rowCount = (int) duckdb_vector_size(); int expected = 0x01020304; - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); writable.setInt(0, expected); ByteBuffer rawData = duckdb_vector_get_data(vec, (long) rowCount * Integer.BYTES); assertEquals(rawData.order(ByteOrder.nativeOrder()).getInt(0), expected); - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); assertEquals(readable.getInt(0), expected); duckdb_destroy_vector(vec); @@ -231,12 +194,12 @@ public static void test_bindings_writable_vector_stack_trace_origin() throws Exc ByteBuffer vec = duckdb_create_vector(lt); int rowCount = (int) duckdb_vector_size(); - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); try { writable.setInt(0, 42); fail("Expected setInt to reject VARCHAR vector"); - } catch (DuckDBFunctionException exception) { + } catch (DuckDBFunctions.FunctionException exception) { assertTrue(exception.getMessage().contains("Expected vector type INTEGER, found VARCHAR")); assertEquals(exception.getStackTrace()[0].getMethodName(), "setInt"); } @@ -263,7 +226,7 @@ public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), new BigInteger("18446744073709551615")}; - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); for (int i = 0; i < values.length; i++) { writable.setUint64(i, values[i]); } @@ -274,7 +237,7 @@ public static void test_bindings_vector_ubigint_native_endian_roundtrip() throws assertEquals(nativeData.getLong(i * Long.BYTES), values[i].longValue()); } - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); for (int i = 0; i < values.length; i++) { assertEquals(readable.getUint64(i), values[i]); } @@ -293,7 +256,7 @@ public static void test_bindings_vector_uhugeint_native_endian_roundtrip() throw new BigInteger[] {BigInteger.ZERO, new BigInteger("42"), new BigInteger("9223372036854775808"), new BigInteger("340282366920938463463374607431768211455")}; - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); for (int i = 0; i < values.length; i++) { writable.setUHugeInt(i, values[i]); } @@ -306,7 +269,7 @@ public static void test_bindings_vector_uhugeint_native_endian_roundtrip() throw assertEquals(nativeData.getLong(offset + Long.BYTES), values[i].shiftRight(Long.SIZE).longValue()); } - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); for (int i = 0; i < values.length; i++) { assertEquals(readable.getUHugeInt(i), values[i]); } @@ -423,7 +386,7 @@ public static void test_bindings_writable_vector_validity_word_boundaries() thro ByteBuffer vec = duckdb_create_vector(lt); long rowCount = 70; - DuckDBWritableVector writable = new DuckDBWritableVectorImpl(vec, rowCount); + DuckDBWritableVector writable = new DuckDBWritableVector(vec, rowCount); long[] boundaryRows = new long[] {63, 64, 65, rowCount - 1}; long[] sentinelRows = new long[] {62, 66, 67, 68}; @@ -438,7 +401,7 @@ public static void test_bindings_writable_vector_validity_word_boundaries() thro writable.setNull(row); } - DuckDBReadableVector readable = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector readable = new DuckDBReadableVector(vec, rowCount); for (long row : boundaryRows) { assertTrue(readable.isNull(row)); } @@ -451,7 +414,7 @@ public static void test_bindings_writable_vector_validity_word_boundaries() thro writable.setInt(row, (int) (row + 1000)); } - DuckDBReadableVector revalidated = new DuckDBReadableVectorImpl(vec, rowCount); + DuckDBReadableVector revalidated = new DuckDBReadableVector(vec, rowCount); for (long row : boundaryRows) { assertFalse(revalidated.isNull(row)); assertEquals(revalidated.getInt(row), (int) (row + 1000)); diff --git a/src/test/java/org/duckdb/TestScalarFunctions.java b/src/test/java/org/duckdb/TestScalarFunctions.java index 8c4fc856e..f63da607a 100644 --- a/src/test/java/org/duckdb/TestScalarFunctions.java +++ b/src/test/java/org/duckdb/TestScalarFunctions.java @@ -1,6 +1,7 @@ package org.duckdb; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; import static org.duckdb.DuckDBBindings.*; import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER; import static org.duckdb.TestDuckDBJDBC.JDBC_URL; @@ -23,27 +24,142 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.time.OffsetDateTime; -import java.time.ZoneOffset; import java.util.List; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import org.duckdb.DuckDBFunctions.RegisteredFunction; public class TestScalarFunctions { private interface ResultSetVerifier { void verify(ResultSet rs) throws Exception; } - private static int sumNonNullIntColumns(DuckDBScalarContext ctx, DuckDBScalarRow row) { + private static int sumNonNullIntColumns(DuckDBDataChunkReader input, long rowIndex) { int sum = 0; - for (int columnIndex = 0; columnIndex < ctx.columnCount(); columnIndex++) { - if (!row.isNull(columnIndex)) { - sum += row.getInt(columnIndex); + for (int columnIndex = 0; columnIndex < input.columnCount(); columnIndex++) { + if (!input.vector(columnIndex).isNull(rowIndex)) { + sum += input.vector(columnIndex).getInt(rowIndex); } } return sum; } + private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, DuckDBScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); + DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { + assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); + } + } + + private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, DuckDBScalarFunction function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withVectorizedFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, Class parameterType, Class returnType, + Function function, String query, ResultSetVerifier verifier) + throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBColumnType parameterType, + DuckDBColumnType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertUnaryJavaFunction(String functionName, DuckDBLogicalType parameterType, + DuckDBLogicalType returnType, Function function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(parameterType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertBinaryJavaFunction(String functionName, Class leftType, Class rightType, + Class returnType, BiFunction function, String query, + ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withParameter(leftType) + .withParameter(rightType) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullaryJavaFunction(String functionName, Class returnType, Supplier function, + String query, ResultSetVerifier verifier) throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + DuckDBFunctions.scalarFunction() + .withName(functionName) + .withReturnType(returnType) + .withFunction(function) + .register(conn); + try (ResultSet rs = stmt.executeQuery(query)) { + verifier.verify(rs); + } + } + } + + private static void assertNullRow(ResultSet rs) throws Exception { + assertEquals(rs.getObject(1), null); + assertTrue(rs.wasNull()); + } + public static void test_bindings_scalar_function() throws Exception { ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer scalarFunction = duckdb_create_scalar_function(); @@ -75,22 +191,28 @@ public static void test_register_scalar_function_builder() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - DuckDBRegisteredFunction function = - DuckDBFunctions.scalarFunction() - .withName("java_add_int_builder") - .withParameter(intType) - .withReturnType(intType) - .withVectorizedFunction(ctx -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); - }) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder") + .withParameter(intType) + .withReturnType(intType) + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setInt(row, in.getInt(row) + 1); + } + }); + }) + .register(conn); assertEquals(function.name(), "java_add_int_builder"); assertEquals(function.parameterTypes().size(), 1); assertEquals(function.parameterTypes().get(0), intType); assertEquals(function.returnType(), intType); assertEquals(function.varArgType(), null); assertEquals(function.isVolatile(), false); - assertEquals(function.hasSpecialHandling(), false); + assertEquals(function.isNullInNullOut(), false); assertEquals(function.propagateNulls(), false); try (ResultSet rs = @@ -129,7 +251,7 @@ public static void test_register_scalar_function_builder_connection_without_unwr } public static void test_register_scalar_function_builder_returns_detached_metadata() throws Exception { - DuckDBRegisteredFunction function; + RegisteredFunction function; try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); DuckDBScalarFunctionBuilder builder = DuckDBFunctions.scalarFunction()) { function = builder.withName("java_add_int_detached") @@ -147,7 +269,7 @@ public static void test_register_scalar_function_builder_returns_detached_metada assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); assertEquals(function.returnColumnType(), DuckDBColumnType.INTEGER); assertNotNull(function.function()); - assertEquals(function.functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); + assertEquals(function.functionKind(), DuckDBFunctions.Kind.SCALAR); assertTrue(function.isScalar()); assertEquals(function.propagateNulls(), false); @@ -167,17 +289,17 @@ public static void test_register_scalar_function_builder_returns_detached_metada public static void test_register_scalar_function_registry_records_registered_functions() throws Exception { DuckDBDriver.clearFunctionsRegistry(); try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_registry_recorded") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withFunction((Integer x) -> x + 1) - .register(conn); - - List registeredFunctions = DuckDBDriver.registeredFunctions(); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_registry_recorded") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x + 1) + .register(conn); + + List registeredFunctions = DuckDBDriver.registeredFunctions(); assertEquals(registeredFunctions.size(), 1); assertEquals(registeredFunctions.get(0), function); - assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctions.DuckDBFunctionKind.SCALAR); + assertEquals(registeredFunctions.get(0).functionKind(), DuckDBFunctions.Kind.SCALAR); assertTrue(registeredFunctions.get(0).isScalar()); try (ResultSet rs = stmt.executeQuery("SELECT java_registry_recorded(41)")) { @@ -200,7 +322,7 @@ public static void test_register_scalar_function_registry_is_read_only() throws .withFunction((Integer x) -> x + 1) .register(conn); - List registeredFunctions = DuckDBDriver.registeredFunctions(); + List registeredFunctions = DuckDBDriver.registeredFunctions(); assertThrows(() -> { registeredFunctions.add(null); }, UnsupportedOperationException.class); } finally { DuckDBDriver.clearFunctionsRegistry(); @@ -254,7 +376,7 @@ public static void test_register_scalar_function_registry_allows_duplicate_names .withFunction((Integer x) -> x + 2) .register(connB); - List registeredFunctions = DuckDBDriver.registeredFunctions(); + List registeredFunctions = DuckDBDriver.registeredFunctions(); assertEquals(registeredFunctions.size(), 2); assertEquals(registeredFunctions.get(0).name(), "java_registry_duplicate_name"); assertEquals(registeredFunctions.get(1).name(), "java_registry_duplicate_name"); @@ -300,20 +422,20 @@ public static void test_register_scalar_function_builder_varargs_and_flags() thr try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement(); DuckDBLogicalType intType = DuckDBLogicalType.of(DuckDBColumnType.INTEGER)) { - DuckDBRegisteredFunction function = + RegisteredFunction function = DuckDBFunctions.scalarFunction() .withName("java_sum_varargs_builder") .withParameter(intType) .withVarArgs(intType) .withReturnType(intType) .withVolatile() - .withSpecialHandling() - .withVectorizedFunction( - ctx -> { ctx.stream().forEachOrdered(row -> { row.setInt(sumNonNullIntColumns(ctx, row)); }); }) + .withVectorizedFunction((input, output) -> { + input.stream().forEach(row -> output.setInt(row, sumNonNullIntColumns(input, row))); + }) .register(conn); assertEquals(function.varArgType(), intType); assertEquals(function.isVolatile(), true); - assertEquals(function.hasSpecialHandling(), true); + assertEquals(function.isNullInNullOut(), false); assertEquals(function.propagateNulls(), false); try (ResultSet rs = @@ -329,15 +451,21 @@ public static void test_register_scalar_function_builder_varargs_and_flags() thr public static void test_register_scalar_function_builder_column_type_overloads() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = - DuckDBFunctions.scalarFunction() - .withName("java_add_int_builder_col_type") - .withParameter(DuckDBColumnType.INTEGER) - .withReturnType(DuckDBColumnType.INTEGER) - .withVectorizedFunction(ctx -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); - }) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_builder_col_type") + .withParameter(DuckDBColumnType.INTEGER) + .withReturnType(DuckDBColumnType.INTEGER) + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setInt(row, in.getInt(row) + 1); + } + }); + }) + .register(conn); assertEquals(function.parameterColumnTypes().size(), 1); assertEquals(function.parameterColumnTypes().get(0), DuckDBColumnType.INTEGER); assertEquals(function.parameterTypes().get(0), null); @@ -406,12 +534,12 @@ public static void test_register_scalar_function_builder_java_function() throws public static void test_register_scalar_function_builder_java_function_propagate_nulls_false() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_int_function_nullable") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withFunction((Integer x) -> x == null ? 99 : x + 1) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_function_nullable") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withFunction((Integer x) -> x == null ? 99 : x + 1) + .register(conn); assertEquals(function.propagateNulls(), false); try (ResultSet rs = stmt.executeQuery( @@ -433,8 +561,7 @@ public static void test_register_scalar_function_builder_java_bifunction() throw Statement stmt = conn.createStatement()) { DuckDBFunctions.scalarFunction() .withName("java_add_int_bifunction") - .withParameter(Integer.class) - .withParameter(Integer.class) + .withParameters(Integer.class, Integer.class) .withReturnType(Integer.class) .withFunction((Integer x, Integer y) -> null != x && null != y ? x + y : null) .register(conn); @@ -458,11 +585,10 @@ public static void test_register_scalar_function_builder_java_bifunction() throw public static void test_register_scalar_function_builder_java_bifunction_propagate_nulls_false() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = + RegisteredFunction function = DuckDBFunctions.scalarFunction() .withName("java_add_int_bifunction_nullable") - .withParameter(Integer.class) - .withParameter(Integer.class) + .withParameters(Integer.class, Integer.class) .withReturnType(Integer.class) .withFunction( (Integer left, Integer right) -> (left == null ? 0 : left) + (right == null ? 0 : right)) @@ -488,12 +614,12 @@ public static void test_register_scalar_function_builder_java_bifunction_propaga public static void test_register_scalar_function_builder_with_int_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_int_with_int_function") - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withIntFunction(x -> x + 1) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_function") + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction(x -> x + 1) + .register(conn); assertEquals(function.propagateNulls(), true); try (ResultSet rs = stmt.executeQuery( @@ -512,13 +638,13 @@ public static void test_register_scalar_function_builder_with_int_function() thr public static void test_register_scalar_function_builder_with_int_binary_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_int_with_int_binary_function") - .withParameter(Integer.class) - .withParameter(Integer.class) - .withReturnType(Integer.class) - .withIntFunction((left, right) -> left + right) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_int_with_int_binary_function") + .withParameter(Integer.class) + .withParameter(Integer.class) + .withReturnType(Integer.class) + .withIntFunction((left, right) -> left + right) + .register(conn); assertEquals(function.propagateNulls(), true); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_with_int_binary_function(a, b) " @@ -539,12 +665,12 @@ public static void test_register_scalar_function_builder_with_int_binary_functio public static void test_register_scalar_function_builder_with_double_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_double_with_double_function") - .withParameter(Double.class) - .withReturnType(Double.class) - .withDoubleFunction(x -> x + 0.5d) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_function") + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction(x -> x + 0.5d) + .register(conn); assertEquals(function.propagateNulls(), true); try (ResultSet rs = stmt.executeQuery( @@ -563,13 +689,13 @@ public static void test_register_scalar_function_builder_with_double_function() public static void test_register_scalar_function_builder_with_double_binary_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_double_with_double_binary_function") - .withParameter(Double.class) - .withParameter(Double.class) - .withReturnType(Double.class) - .withDoubleFunction((left, right) -> left + right) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_double_with_double_binary_function") + .withParameter(Double.class) + .withParameter(Double.class) + .withReturnType(Double.class) + .withDoubleFunction((left, right) -> left + right) + .register(conn); assertEquals(function.propagateNulls(), true); try (ResultSet rs = @@ -591,12 +717,12 @@ public static void test_register_scalar_function_builder_with_double_binary_func public static void test_register_scalar_function_builder_with_long_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_long_with_long_function") - .withParameter(Long.class) - .withReturnType(Long.class) - .withLongFunction(x -> x + 3) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_function") + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction(x -> x + 3) + .register(conn); assertEquals(function.propagateNulls(), true); try ( @@ -616,13 +742,13 @@ public static void test_register_scalar_function_builder_with_long_function() th public static void test_register_scalar_function_builder_with_long_binary_function() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - DuckDBRegisteredFunction function = DuckDBFunctions.scalarFunction() - .withName("java_add_long_with_long_binary_function") - .withParameter(Long.class) - .withParameter(Long.class) - .withReturnType(Long.class) - .withLongFunction((left, right) -> left + right) - .register(conn); + RegisteredFunction function = DuckDBFunctions.scalarFunction() + .withName("java_add_long_with_long_binary_function") + .withParameter(Long.class) + .withParameter(Long.class) + .withReturnType(Long.class) + .withLongFunction((left, right) -> left + right) + .register(conn); assertEquals(function.propagateNulls(), true); try (ResultSet rs = stmt.executeQuery("SELECT java_add_long_with_long_binary_function(a, b) " @@ -805,6 +931,7 @@ public static void test_register_scalar_function_builder_java_varargs_function() .withParameter(Integer.class) .withVarArgs(intType) .withReturnType(Integer.class) + .withNullInNullOut() .withVarArgsFunction(args -> { int sum = 0; for (Object arg : args) { @@ -941,7 +1068,7 @@ public static void test_register_scalar_function_builder_java_function_supported rs -> { assertTrue(rs.next()); assertTrue(rs.getObject(1, OffsetDateTime.class) - .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, UTC))); assertTrue(rs.next()); assertNullRow(rs); assertFalse(rs.next()); @@ -1135,8 +1262,16 @@ public static void test_register_scalar_function_typed_logical_type() throws Exc .withName("java_add_int_typed") .withParameter(intType) .withReturnType(intType) - .withVectorizedFunction( - ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setInt(row, in.getInt(row) + 1); + } + }); + }) .register(conn); try (ResultSet rs = stmt.executeQuery("SELECT java_add_int_typed(v) FROM (VALUES (1), (NULL), (41)) t(v)")) { @@ -1160,13 +1295,9 @@ public static void test_register_scalar_function_parallel() throws Exception { .withName("java_add_one_bigint") .withParameter(bigintType) .withReturnType(bigintType) - .withVectorizedFunction(ctx -> { - DuckDBWritableVector out = ctx.output(); - DuckDBReadableVector in = ctx.input(0); - long rowCount = ctx.rowCount(); - for (long row = 0; row < rowCount; row++) { - out.setLong(row, in.getLong(row) + 1); - } + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { output.setLong(row, in.getLong(row) + 1); }); }) .register(conn); @@ -1187,8 +1318,16 @@ public static void test_register_scalar_function_context_row_stream_int() throws .withName("java_add_int_row_stream") .withParameter(intType) .withReturnType(intType) - .withVectorizedFunction( - ctx -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }) + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setInt(row, in.getInt(row) + 1); + } + }); + }) .register(conn); try (ResultSet rs = @@ -1212,8 +1351,15 @@ public static void test_register_scalar_function_context_row_stream_double() thr .withName("java_add_double_row_stream") .withParameter(doubleType) .withReturnType(doubleType) - .withVectorizedFunction(ctx -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); + .withVectorizedFunction((input, output) -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setDouble(row, in.getDouble(row) + 1.5d); + } + }); }) .register(conn); @@ -1246,52 +1392,62 @@ public static void test_register_scalar_function_primitive_nulls_handling() thro .withParameter(DuckDBColumnType.FLOAT) .withParameter(DuckDBColumnType.DOUBLE) .withReturnType(DuckDBColumnType.VARCHAR) - .withSpecialHandling() - .withVectorizedFunction(ctx -> { - assertFalse(ctx.nullsPropagated()); - ctx.stream().forEachOrdered(row -> { + .withVectorizedFunction((input, output) -> { + input.stream().forEach(rowIndex -> { try { - DuckDBReadableVector booleanVector = ctx.input(0); - DuckDBReadableVector intVector = ctx.input(5); + DuckDBReadableVector booleanVector = input.vector(0); + DuckDBReadableVector intVector = input.vector(5); assertThrows( - () -> { booleanVector.getBoolean(row.index()); }, DuckDBFunctionException.class); - assertTrue(booleanVector.getBoolean(row.index(), true)); - assertThrows(() -> { intVector.getInt(row.index()); }, DuckDBFunctionException.class); - assertEquals(intVector.getInt(row.index(), 42), 42); + () -> { booleanVector.getBoolean(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertTrue(booleanVector.getBoolean(rowIndex, true)); + assertThrows( + () -> { intVector.getInt(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(intVector.getInt(rowIndex, 42), 42); - assertThrows(() -> { row.getBoolean(0); }, DuckDBFunctionException.class); + assertThrows(() -> { + input.vector(0).getBoolean(rowIndex); + }, DuckDBFunctions.FunctionException.class); try { - row.getBoolean(0); - fail("Expected row.getBoolean(0) to fail on NULL"); - } catch (DuckDBFunctionException exception) { - assertTrue(exception.getMessage().contains("Failed to read BOOLEAN")); - assertNotNull(exception.getCause()); - assertTrue(exception.getCause() instanceof DuckDBFunctionException); - assertTrue(exception.getCause().getMessage().contains("Primitive value for BOOLEAN")); + input.vector(0).getBoolean(rowIndex); + fail("Expected input.vector(0).getBoolean(rowIndex) to fail on NULL"); + } catch (DuckDBFunctions.FunctionException exception) { + assertTrue(exception.getMessage().contains("BOOLEAN")); } - assertTrue(row.getBoolean(0, true)); - assertThrows(() -> { row.getByte(1); }, DuckDBFunctionException.class); - assertEquals(row.getByte(1, (byte) 42), (byte) 42); - assertThrows(() -> { row.getUint8(2); }, DuckDBFunctionException.class); - assertEquals(row.getUint8(2, (short) 42), (short) 42); - assertThrows(() -> { row.getShort(3); }, DuckDBFunctionException.class); - assertEquals(row.getShort(3, (short) 42), (short) 42); - assertThrows(() -> { row.getUint16(4); }, DuckDBFunctionException.class); - assertEquals(row.getUint16(4, 42), 42); - assertThrows(() -> { row.getInt(5); }, DuckDBFunctionException.class); - assertEquals(row.getInt(5, 42), 42); - assertThrows(() -> { row.getUint32(6); }, DuckDBFunctionException.class); - assertEquals(row.getUint32(6, (long) 42), (long) 42); - assertThrows(() -> { row.getLong(7); }, DuckDBFunctionException.class); - assertEquals(row.getLong(7, (long) 42), (long) 42); - assertThrows(() -> { row.getFloat(8); }, DuckDBFunctionException.class); - assertEquals(row.getFloat(8, (float) 42.1), (float) 42.1); - assertThrows(() -> { row.getDouble(9); }, DuckDBFunctionException.class); - assertEquals(row.getDouble(9, 42.1), 42.1); + assertTrue(input.vector(0).getBoolean(rowIndex, true)); + assertThrows( + () -> { input.vector(1).getByte(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(1).getByte(rowIndex, (byte) 42), (byte) 42); + assertThrows( + () -> { input.vector(2).getUint8(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(2).getUint8(rowIndex, (short) 42), (short) 42); + assertThrows( + () -> { input.vector(3).getShort(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(3).getShort(rowIndex, (short) 42), (short) 42); + assertThrows(() -> { + input.vector(4).getUint16(rowIndex); + }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(4).getUint16(rowIndex, 42), 42); + assertThrows( + () -> { input.vector(5).getInt(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(5).getInt(rowIndex, 42), 42); + assertThrows(() -> { + input.vector(6).getUint32(rowIndex); + }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(6).getUint32(rowIndex, (long) 42), (long) 42); + assertThrows( + () -> { input.vector(7).getLong(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(7).getLong(rowIndex, (long) 42), (long) 42); + assertThrows( + () -> { input.vector(8).getFloat(rowIndex); }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(8).getFloat(rowIndex, (float) 42.1), (float) 42.1); + assertThrows(() -> { + input.vector(9).getDouble(rowIndex); + }, DuckDBFunctions.FunctionException.class); + assertEquals(input.vector(9).getDouble(rowIndex, 42.1), 42.1); } catch (Exception e) { throw new RuntimeException(e); } - row.setString("ok"); + output.setString(rowIndex, "ok"); }); }) .register(conn); @@ -1310,11 +1466,11 @@ public static void test_register_scalar_function_primitive_nulls_handling() thro public static void test_register_scalar_function_context_row_stream_propagate_nulls_false() throws Exception { assertUnaryScalarFunction( "java_suffix_varchar_row_stream_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - ctx - -> { - ctx.stream().forEachOrdered(row -> { - String value = row.getString(0); - row.setString(value == null ? "NULL_SEEN" : value + "_ok"); + (input, output) + -> { + input.stream().forEach(row -> { + String value = input.vector(0).getString(row); + output.setString(row, value == null ? "NULL_SEEN" : value + "_ok"); }); }, "SELECT java_suffix_varchar_row_stream_nullable(v) FROM (VALUES ('duck'), (NULL), ('db')) t(v)", @@ -1384,7 +1540,7 @@ public static void test_register_scalar_function_exception_propagation() throws .withName("java_throws_exception") .withParameter(intType) .withReturnType(intType) - .withVectorizedFunction(ctx -> { throw new IllegalStateException("boom"); }) + .withVectorizedFunction((input, output) -> { throw new IllegalStateException("boom"); }) .register(conn); String message = assertThrows(() -> { stmt.executeQuery("SELECT java_throws_exception(1)"); }, SQLException.class); @@ -1395,45 +1551,68 @@ public static void test_register_scalar_function_exception_propagation() throws } public static void test_register_scalar_function_boolean() throws Exception { - assertUnaryScalarFunction( - "java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setBoolean(!row.getBoolean(0))); }, - "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Boolean.class), false); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Boolean.class), true); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_not_bool", DuckDBColumnType.BOOLEAN, DuckDBColumnType.BOOLEAN, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setBoolean(row, !in.getBoolean(row)); + } + }); + }, + "SELECT java_not_bool(v) FROM (VALUES (TRUE), (NULL), (FALSE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), false); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Boolean.class), true); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_tinyint() throws Exception { - assertUnaryScalarFunction( - "java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setByte((byte) (row.getByte(0) + 1))); }, - "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Byte.class), (byte) 42); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Byte.class), (byte) -1); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_tinyint", DuckDBColumnType.TINYINT, DuckDBColumnType.TINYINT, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setByte(row, (byte) (in.getByte(row) + 1)); + } + }); + }, + "SELECT java_add_tinyint(v) FROM (VALUES (41::TINYINT), (NULL), (-2::TINYINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) 42); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Byte.class), (byte) -1); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_smallint() throws Exception { assertUnaryScalarFunction( "java_add_smallint", DuckDBColumnType.SMALLINT, DuckDBColumnType.SMALLINT, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setShort((short) (row.getShort(0) + 2))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setShort(row, (short) (in.getShort(row) + 2)); + } + }); }, "SELECT java_add_smallint(v) FROM (VALUES (40::SMALLINT), (NULL), (-4::SMALLINT)) t(v)", rs -> { @@ -1448,29 +1627,42 @@ public static void test_register_scalar_function_smallint() throws Exception { } public static void test_register_scalar_function_integer() throws Exception { - assertUnaryScalarFunction( - "java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setInt(row.getInt(0) + 1)); }, - "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 2); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Integer.class), 42); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setInt(row, in.getInt(row) + 1); + } + }); + }, + "SELECT java_add_int(v) FROM (VALUES (1), (NULL), (41)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 2); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Integer.class), 42); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_integer_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_int", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> { - row.setNull(); - row.setInt(row.getInt(0) + 1); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setNull(row); + output.setInt(row, in.getInt(row) + 1); + } }); }, "SELECT java_revalidate_int(v) FROM (VALUES (41), (NULL)) t(v)", @@ -1485,27 +1677,44 @@ public static void test_register_scalar_function_integer_revalidates_after_null( } public static void test_register_scalar_function_bigint() throws Exception { - assertUnaryScalarFunction( - "java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setLong(row.getLong(0) + 3)); }, - "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Long.class), 42L); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Long.class), -2L); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_bigint", DuckDBColumnType.BIGINT, DuckDBColumnType.BIGINT, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setLong(row, in.getLong(row) + 3); + } + }); + }, + "SELECT java_add_bigint(v) FROM (VALUES (39::BIGINT), (NULL), (-5::BIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), 42L); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Long.class), -2L); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_utinyint() throws Exception { assertUnaryScalarFunction( "java_add_utinyint", DuckDBColumnType.UTINYINT, DuckDBColumnType.UTINYINT, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint8(row.getUint8(0) + 1)); }, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setUint8(row, (short) (in.getUint8(row) + 1)); + } + }); + }, "SELECT java_add_utinyint(v) FROM (VALUES (41::UTINYINT), (NULL), (254::UTINYINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -1521,8 +1730,17 @@ public static void test_register_scalar_function_utinyint() throws Exception { public static void test_register_scalar_function_usmallint() throws Exception { assertUnaryScalarFunction( "java_add_usmallint", DuckDBColumnType.USMALLINT, DuckDBColumnType.USMALLINT, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint16(row.getUint16(0) + 2)); }, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setUint16(row, in.getUint16(row) + 2); + } + }); + }, "SELECT java_add_usmallint(v) FROM (VALUES (40::USMALLINT), (NULL), (65533::USMALLINT)) t(v)", rs -> { assertTrue(rs.next()); @@ -1538,8 +1756,17 @@ public static void test_register_scalar_function_usmallint() throws Exception { public static void test_register_scalar_function_uinteger() throws Exception { assertUnaryScalarFunction( "java_add_uinteger", DuckDBColumnType.UINTEGER, DuckDBColumnType.UINTEGER, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint32(row.getUint32(0) + 3)); }, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setUint32(row, in.getUint32(row) + 3); + } + }); + }, "SELECT java_add_uinteger(v) FROM (VALUES (39::UINTEGER), (NULL), (4294967292::UINTEGER)) t(v)", rs -> { assertTrue(rs.next()); @@ -1553,33 +1780,46 @@ public static void test_register_scalar_function_uinteger() throws Exception { } public static void test_register_scalar_function_ubigint() throws Exception { - assertUnaryScalarFunction( - "java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, - ctx - -> { - BigInteger increment = BigInteger.ONE; - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setUint64(row.getUint64(0).add(increment))); - }, - "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " - + "(18446744073709551614::UBIGINT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("18446744073709551615")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_ubigint", DuckDBColumnType.UBIGINT, DuckDBColumnType.UBIGINT, + (input, output) + -> { + BigInteger increment = BigInteger.ONE; + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setUint64(row, in.getUint64(row).add(increment)); + } + }); + }, + "SELECT java_add_ubigint(v) FROM (VALUES (41::UBIGINT), (NULL), " + + "(18446744073709551614::UBIGINT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), new BigInteger("42")); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, BigInteger.class), + new BigInteger("18446744073709551615")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_uhugeint() throws Exception { assertUnaryScalarFunction("java_add_uhugeint", DuckDBColumnType.UHUGEINT, DuckDBColumnType.UHUGEINT, - ctx - -> { + (input, output) + -> { BigInteger increment = BigInteger.ONE; - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setUHugeInt(row.getUHugeInt(0).add(increment))); + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setUHugeInt(row, in.getUHugeInt(row).add(increment)); + } + }); }, "SELECT java_add_uhugeint(v) FROM (VALUES (CAST('41' AS UHUGEINT)), (NULL), " + "(CAST('340282366920938463463374607431768211454' AS UHUGEINT))) t(v)", @@ -1621,47 +1861,69 @@ public static void test_register_scalar_function_builder_java_function_uhugeint( } public static void test_register_scalar_function_float() throws Exception { - assertUnaryScalarFunction( - "java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setFloat(row.getFloat(0) + 1.25f)); }, - "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Float.class), 42.0f); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Float.class), -1.25f); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_float", DuckDBColumnType.FLOAT, DuckDBColumnType.FLOAT, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setFloat(row, in.getFloat(row) + 1.25f); + } + }); + }, + "SELECT java_add_float(v) FROM (VALUES (40.75::FLOAT), (NULL), (-2.5::FLOAT)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), 42.0f); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Float.class), -1.25f); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_double() throws Exception { - assertUnaryScalarFunction( - "java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDouble(row.getDouble(0) + 1.5d)); }, - "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Double.class), 42.0d); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, Double.class), -1.5d); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_double", DuckDBColumnType.DOUBLE, DuckDBColumnType.DOUBLE, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setDouble(row, in.getDouble(row) + 1.5d); + } + }); + }, + "SELECT java_add_double(v) FROM (VALUES (40.5::DOUBLE), (NULL), (-3.0::DOUBLE)) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), 42.0d); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, Double.class), -1.5d); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_decimal() throws Exception { try (DuckDBLogicalType decimalType = DuckDBLogicalType.decimal(38, 10)) { assertUnaryScalarFunction("java_add_decimal", decimalType, decimalType, - ctx - -> { + (input, output) + -> { BigDecimal increment = new BigDecimal("0.0000000001"); - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setBigDecimal(row.getBigDecimal(0).add(increment))); + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setBigDecimal(row, in.getBigDecimal(row).add(increment)); + } + }); }, "SELECT java_add_decimal(v) FROM (VALUES " + "(CAST('12345678901234567890.1234567890' AS DECIMAL(38,10))), " @@ -1688,12 +1950,8 @@ public static void test_register_scalar_function_decimal_precision_overflow() th .withName("java_decimal_precision_overflow") .withParameter(decimalType) .withReturnType(decimalType) - .withVectorizedFunction(ctx -> { - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); - for (int i = 0; i < rowCount; i++) { - out.setBigDecimal(i, new BigDecimal("12345678901.23")); - } + .withVectorizedFunction((input, output) -> { + input.stream().forEach(i -> { output.setBigDecimal(i, new BigDecimal("12345678901.23")); }); }) .register(conn); @@ -1712,12 +1970,8 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws .withName("java_decimal_scale_overflow") .withParameter(decimalType) .withReturnType(decimalType) - .withVectorizedFunction(ctx -> { - DuckDBWritableVector out = ctx.output(); - long rowCount = ctx.rowCount(); - for (int i = 0; i < rowCount; i++) { - out.setBigDecimal(i, new BigDecimal("1.234")); - } + .withVectorizedFunction((input, output) -> { + input.stream().forEach(i -> { output.setBigDecimal(i, new BigDecimal("1.234")); }); }) .register(conn); @@ -1731,9 +1985,16 @@ public static void test_register_scalar_function_decimal_scale_overflow() throws public static void test_register_scalar_function_date() throws Exception { assertUnaryScalarFunction( "java_add_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setDate(row.getLocalDate(0).plusDays(2))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setDate(row, in.getLocalDate(row).plusDays(2)); + } + }); }, "SELECT java_add_date(v) FROM (VALUES (DATE '2024-07-20'), (NULL), (DATE '1969-12-31')) t(v)", rs -> { @@ -1750,11 +2011,17 @@ public static void test_register_scalar_function_date() throws Exception { public static void test_register_scalar_function_date_from_java_util_date() throws Exception { assertUnaryScalarFunction("java_date_from_util_date", DuckDBColumnType.DATE, DuckDBColumnType.DATE, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> { - LocalDate value = row.getLocalDate(0).plusDays(1); - row.setDate(java.util.Date.from(value.atStartOfDay(UTC).toInstant())); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + LocalDate value = in.getLocalDate(row).plusDays(1); + output.setDate(row, + java.util.Date.from(value.atStartOfDay(UTC).toInstant())); + } }); }, "SELECT java_date_from_util_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", @@ -1769,10 +2036,16 @@ public static void test_register_scalar_function_date_from_java_util_date() thro public static void test_register_scalar_function_timestamp() throws Exception { assertUnaryScalarFunction("java_add_timestamp", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setTimestamp(row.getLocalDateTime(0).plusMinutes(30))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, in.getLocalDateTime(row).plusMinutes(30)); + } + }); }, "SELECT java_add_timestamp(v) FROM (VALUES " + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " @@ -1795,10 +2068,16 @@ public static void test_register_scalar_function_timestamp() throws Exception { public static void test_register_scalar_function_timestamp_s() throws Exception { assertUnaryScalarFunction( "java_add_timestamp_s", DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_S, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setTimestamp(row.getLocalDateTime(0).plusSeconds(2))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, in.getLocalDateTime(row).plusSeconds(2)); + } + }); }, "SELECT java_add_timestamp_s(v) FROM (VALUES (TIMESTAMP_S '2024-07-21 12:34:56'), (NULL)) t(v)", rs -> { @@ -1811,25 +2090,34 @@ public static void test_register_scalar_function_timestamp_s() throws Exception } public static void test_register_scalar_function_timestamp_s_pre_epoch() throws Exception { - assertUnaryScalarFunction("java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, - DuckDBColumnType.TIMESTAMP_S, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, - "SELECT java_copy_timestamp_s_pre_epoch(v) FROM (VALUES " - + "(TIMESTAMP '1969-12-31 23:59:59.999')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1969-12-31 23:59:59")); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_copy_timestamp_s_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_S, + (input, output) + -> { + input.stream().forEach( + rowIndex -> output.setTimestamp(rowIndex, input.vector(0).getLocalDateTime(rowIndex))); + }, + "SELECT java_copy_timestamp_s_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.999')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getTimestamp(1), Timestamp.valueOf("1969-12-31 23:59:59")); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_ms() throws Exception { assertUnaryScalarFunction("java_add_timestamp_ms", DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_MS, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(7_000_000))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, in.getLocalDateTime(row).plusNanos(7_000_000)); + } + }); }, "SELECT java_add_timestamp_ms(v) FROM (VALUES " + "(TIMESTAMP_MS '2024-07-21 12:34:56.123'), (NULL)) t(v)", @@ -1844,26 +2132,35 @@ public static void test_register_scalar_function_timestamp_ms() throws Exception } public static void test_register_scalar_function_timestamp_ms_pre_epoch() throws Exception { - assertUnaryScalarFunction("java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, - DuckDBColumnType.TIMESTAMP_MS, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getLocalDateTime(0))); }, - "SELECT java_copy_timestamp_ms_pre_epoch(v) FROM (VALUES " - + "(TIMESTAMP '1969-12-31 23:59:59.9995')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, LocalDateTime.class), - LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999_000_000)); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction( + "java_copy_timestamp_ms_pre_epoch", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_MS, + (input, output) + -> { + input.stream().forEach( + rowIndex -> output.setTimestamp(rowIndex, input.vector(0).getLocalDateTime(rowIndex))); + }, + "SELECT java_copy_timestamp_ms_pre_epoch(v) FROM (VALUES " + + "(TIMESTAMP '1969-12-31 23:59:59.9995')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, LocalDateTime.class), + LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999_000_000)); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamp_ns() throws Exception { assertUnaryScalarFunction("java_add_timestamp_ns", DuckDBColumnType.TIMESTAMP_NS, DuckDBColumnType.TIMESTAMP_NS, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setTimestamp(row.getLocalDateTime(0).plusNanos(789))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, in.getLocalDateTime(row).plusNanos(789)); + } + }); }, "SELECT java_add_timestamp_ns(v) FROM (VALUES " + "(TIMESTAMP_NS '2024-07-21 12:34:56.123456789'), (NULL)) t(v)", @@ -1878,38 +2175,47 @@ public static void test_register_scalar_function_timestamp_ns() throws Exception } public static void test_register_scalar_function_timestamptz() throws Exception { - assertUnaryScalarFunction( - "java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setOffsetDateTime(row.getOffsetDateTime(0).plusMinutes(5))); - }, - "SELECT java_add_timestamptz(v) FROM (VALUES " - + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", - rs -> { - assertTrue(rs.next()); - assertTrue(rs.getObject(1, OffsetDateTime.class) - .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, ZoneOffset.UTC))); - assertTrue(rs.next()); - assertNullRow(rs); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_add_timestamptz", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setOffsetDateTime(row, in.getOffsetDateTime(row).plusMinutes(5)); + } + }); + }, + "SELECT java_add_timestamptz(v) FROM (VALUES " + + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00'), (NULL)) t(v)", + rs -> { + assertTrue(rs.next()); + assertTrue( + rs.getObject(1, OffsetDateTime.class) + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 39, 56, 123456000, UTC))); + assertTrue(rs.next()); + assertNullRow(rs); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_timestamptz_set_timestamp() throws Exception { assertUnaryScalarFunction( "java_copy_timestamptz_with_timestamp", DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, - ctx - -> { ctx.stream().forEachOrdered(row -> row.setTimestamp(row.getTimestamp(0))); }, + (input, output) + -> { + input.stream().forEach( + rowIndex -> output.setTimestamp(rowIndex, input.vector(0).getTimestamp(rowIndex))); + }, "SELECT java_copy_timestamptz_with_timestamp(v) FROM (VALUES " + "(TIMESTAMPTZ '2024-07-21 12:34:56.123456+02:00')) t(v)", rs -> { assertTrue(rs.next()); assertTrue(rs.getObject(1, OffsetDateTime.class) - .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 34, 56, 123456000, ZoneOffset.UTC))); + .isEqual(OffsetDateTime.of(2024, 7, 21, 10, 34, 56, 123456000, UTC))); assertFalse(rs.next()); }); } @@ -1917,11 +2223,17 @@ public static void test_register_scalar_function_timestamptz_set_timestamp() thr public static void test_register_scalar_function_timestamp_from_java_util_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_date", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - ctx - -> { + (input, output) + -> { long oneSecondMillis = 1000L; - ctx.propagateNulls(true).stream().forEachOrdered( - row -> { row.setTimestamp(new java.util.Date(row.getTimestamp(0).getTime() + oneSecondMillis)); }); + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, new java.util.Date(in.getTimestamp(row).getTime() + oneSecondMillis)); + } + }); }, "SELECT epoch_ms(java_timestamp_from_util_date(v)) FROM (VALUES " + "(TIMESTAMP '2024-07-21 12:34:56.123456'), " @@ -1940,11 +2252,16 @@ public static void test_register_scalar_function_timestamp_from_java_util_date() public static void test_register_scalar_function_timestamp_from_java_util_date_typed_timestamp() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_ts", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> { - java.util.Date value = Timestamp.valueOf(row.getLocalDateTime(0).plusNanos(789000)); - row.setTimestamp(value); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + java.util.Date value = Timestamp.valueOf(in.getLocalDateTime(row).plusNanos(789000)); + output.setTimestamp(row, value); + } }); }, "SELECT java_timestamp_from_util_ts(v) FROM (VALUES (TIMESTAMP '2024-07-21 12:34:56.123456')) t(v)", @@ -1959,11 +2276,11 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_sql_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - ctx - -> { - ctx.stream().forEachOrdered(row -> { - java.util.Date value = Date.valueOf(row.getLocalDate(0)); - row.setTimestamp(value); + (input, output) + -> { + input.stream().forEach(rowIndex -> { + java.util.Date value = Date.valueOf(input.vector(0).getLocalDate(rowIndex)); + output.setTimestamp(rowIndex, value); }); }, "SELECT epoch_ms(java_timestamp_from_util_sql_date(v)) FROM (VALUES (DATE '2024-07-21')) t(v)", @@ -1977,11 +2294,11 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_java_util_date_typed_sql_time() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_util_sql_time", DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP, - ctx - -> { - ctx.stream().forEachOrdered(row -> { + (input, output) + -> { + input.stream().forEach(rowIndex -> { java.util.Date value = Time.valueOf("12:34:56"); - row.setTimestamp(value); + output.setTimestamp(rowIndex, value); }); }, "SELECT epoch_ms(java_timestamp_from_util_sql_time(v)) FROM (VALUES (TIMESTAMP '2024-07-21 00:00:00')) t(v)", @@ -1995,10 +2312,16 @@ public static void test_register_scalar_function_timestamp_from_java_util_date_t public static void test_register_scalar_function_timestamp_from_local_date() throws Exception { assertUnaryScalarFunction( "java_timestamp_from_local_date", DuckDBColumnType.DATE, DuckDBColumnType.TIMESTAMP, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered( - row -> row.setTimestamp(row.getLocalDate(0).plusDays(1))); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setTimestamp(row, in.getLocalDate(row).plusDays(1)); + } + }); }, "SELECT java_timestamp_from_local_date(v) FROM (VALUES (DATE '2024-07-21'), (NULL)) t(v)", rs -> { @@ -2012,48 +2335,69 @@ public static void test_register_scalar_function_timestamp_from_local_date() thr } public static void test_register_scalar_function_varchar() throws Exception { - assertUnaryScalarFunction( - "java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0) + "_java")); }, - "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " - + "('abcdefghijklmnop')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "duck_java"); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_suffix_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setString(row, in.getString(row) + "_java"); + } + }); + }, + "SELECT java_suffix_varchar(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck_java"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop_java"); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_varchar_get_string_handles_null() throws Exception { - assertUnaryScalarFunction( - "java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - ctx - -> { ctx.propagateNulls(true).stream().forEachOrdered(row -> row.setString(row.getString(0))); }, - "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " - + "('abcdefghijklmnop')) t(v)", - rs -> { - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "duck"); - assertTrue(rs.next()); - assertNullRow(rs); - assertTrue(rs.next()); - assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); - assertFalse(rs.next()); - }); + assertUnaryScalarFunction("java_echo_varchar_nullable", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setString(row, in.getString(row)); + } + }); + }, + "SELECT java_echo_varchar_nullable(v) FROM (VALUES ('duck'), (NULL), " + + "('abcdefghijklmnop')) t(v)", + rs -> { + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "duck"); + assertTrue(rs.next()); + assertNullRow(rs); + assertTrue(rs.next()); + assertEquals(rs.getObject(1, String.class), "abcdefghijklmnop"); + assertFalse(rs.next()); + }); } public static void test_register_scalar_function_varchar_revalidates_after_null() throws Exception { assertUnaryScalarFunction("java_revalidate_varchar", DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR, - ctx - -> { - ctx.propagateNulls(true).stream().forEachOrdered(row -> { - row.setNull(); - row.setString(row.getString(0) + "_ok"); + (input, output) + -> { + DuckDBReadableVector in = input.vector(0); + input.stream().forEach(row -> { + if (in.isNull(row)) { + output.setNull(row); + } else { + output.setNull(row); + output.setString(row, in.getString(row) + "_ok"); + } }); }, "SELECT java_revalidate_varchar(v) FROM (VALUES ('duck'), (NULL)) t(v)", @@ -2067,120 +2411,117 @@ public static void test_register_scalar_function_varchar_revalidates_after_null( }); } - private static void assertUnaryScalarFunction(String functionName, DuckDBColumnType parameterType, - DuckDBColumnType returnType, DuckDBScalarFunction function, - String query, ResultSetVerifier verifier) throws Exception { - try (DuckDBLogicalType parameterLogicalType = DuckDBLogicalType.of(parameterType); - DuckDBLogicalType returnLogicalType = DuckDBLogicalType.of(returnType)) { - assertUnaryScalarFunction(functionName, parameterLogicalType, returnLogicalType, function, query, verifier); - } - } - - private static void assertUnaryScalarFunction(String functionName, DuckDBLogicalType parameterType, - DuckDBLogicalType returnType, DuckDBScalarFunction function, - String query, ResultSetVerifier verifier) throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { + public static void test_scalar_function_primitive_types() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { DuckDBFunctions.scalarFunction() - .withName(functionName) - .withParameter(parameterType) - .withReturnType(returnType) - .withVectorizedFunction(function) + .withName("java_int_add") + .withParameter(Integer.TYPE) + .withReturnType(int.class) + .withIntFunction(x -> x + 1) .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_int_add(41::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 42); + assertFalse(rs.next()); } } } - private static void assertUnaryJavaFunction(String functionName, Class parameterType, Class returnType, - Function function, String query, ResultSetVerifier verifier) - throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { + public static void test_scalar_functions_null_in_null_out() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { DuckDBFunctions.scalarFunction() - .withName(functionName) - .withParameter(parameterType) - .withReturnType(returnType) - .withFunction(function) + .withName("java_concat") + .withParameters(String.class, String.class) + .withReturnType(String.class) + .withNullInNullOut() + .withFunction((String left, String right) -> { + if (left == null && right == null) { + return "NULL was passed to me"; + } + return left + right; + }) .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_concat('foo', 'bar')")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "foobar"); + assertFalse(rs.next()); } - } - } - - private static void assertUnaryJavaFunction(String functionName, DuckDBColumnType parameterType, - DuckDBColumnType returnType, Function function, String query, - ResultSetVerifier verifier) throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - DuckDBFunctions.scalarFunction() - .withName(functionName) - .withParameter(parameterType) - .withReturnType(returnType) - .withFunction(function) - .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_concat(NULL, NULL)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), null); + assertTrue(rs.wasNull()); + assertFalse(rs.next()); } - } - } - - private static void assertUnaryJavaFunction(String functionName, DuckDBLogicalType parameterType, - DuckDBLogicalType returnType, Function function, String query, - ResultSetVerifier verifier) throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - DuckDBFunctions.scalarFunction() - .withName(functionName) - .withParameter(parameterType) - .withReturnType(returnType) - .withFunction(function) - .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_concat(a.col1, a.col2) FROM (" + + " SELECT 'foo' col1, 'bar' col2" + + " UNION ALL" + + " SELECT NULL col1, NULL col2" + + " UNION ALL" + + " SELECT 'boo' col1, 'baz' col2" + + ") a")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "foobar"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "NULL was passed to me"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "boobaz"); + assertFalse(rs.next()); } } } - private static void assertBinaryJavaFunction(String functionName, Class leftType, Class rightType, - Class returnType, BiFunction function, String query, - ResultSetVerifier verifier) throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { + public static void test_scalar_function_vectorized_event_label_example() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); + DuckDBLogicalType tsType = DuckDBLogicalType.of(DuckDBColumnType.TIMESTAMP); + DuckDBLogicalType strType = DuckDBLogicalType.of(DuckDBColumnType.VARCHAR); + DuckDBLogicalType dblType = DuckDBLogicalType.of(DuckDBColumnType.DOUBLE)) { DuckDBFunctions.scalarFunction() - .withName(functionName) - .withParameter(leftType) - .withParameter(rightType) - .withReturnType(returnType) - .withFunction(function) + .withName("java_event_label") + .withParameters(tsType, strType, dblType) + .withReturnType(strType) + .withVectorizedFunction((input, output) -> { + input.stream().forEach(row -> { + String value = input.vector(0).getLocalDateTime(row) + " | " + + String.valueOf(input.vector(1).getString(row)).trim().toUpperCase() + " | " + + input.vector(2).getDouble(row, 0.0d); + output.setString(row, value); + }); + }) .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_event_label('2020-12-31 23:58:59', 'foo', 42.2)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "2020-12-31T23:58:59 | FOO | 42.2"); + assertFalse(rs.next()); + } + try (ResultSet rs = stmt.executeQuery("SELECT java_event_label(NULL, NULL, NULL)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "null | NULL | 0.0"); + assertFalse(rs.next()); } } } - private static void assertNullaryJavaFunction(String functionName, Class returnType, Supplier function, - String query, ResultSetVerifier verifier) throws Exception { - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { + public static void test_scalar_function_vectorized_parallel_streams() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();) { DuckDBFunctions.scalarFunction() - .withName(functionName) - .withReturnType(returnType) - .withFunction(function) + .withName("java_add") + .withParameters(long.class, long.class) + .withReturnType(long.class) + .withVectorizedFunction((input, output) -> input.stream().parallel().forEach(row -> { + long left = input.vector(0).getLong(row, 0); + long right = input.vector(1).getLong(row, 0); + output.setLong(row, left + right); + })) .register(conn); - try (ResultSet rs = stmt.executeQuery(query)) { - verifier.verify(rs); + try (ResultSet rs = stmt.executeQuery("SELECT java_add(x, x + 1) FROM range(0, (1<<16) + 3) r(x)")) { + for (long i = 0; i < (1 << 16) + 3; i++) { + assertTrue(rs.next()); + long expected = i + i + 1; + long actual = rs.getLong(1); + assertEquals(actual, expected); + } + assertFalse(rs.next()); } } } - - private static void assertNullRow(ResultSet rs) throws Exception { - assertEquals(rs.getObject(1), null); - assertTrue(rs.wasNull()); - } - - private static final java.time.ZoneOffset UTC = java.time.ZoneOffset.UTC; }