Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.paimon.types.VarBinaryType;
import org.apache.paimon.types.VarCharType;
import org.apache.paimon.types.VariantType;
import org.apache.paimon.types.VectorType;

import org.apache.hadoop.hive.common.type.HiveChar;
import org.apache.hadoop.hive.common.type.HiveVarchar;
Expand Down Expand Up @@ -235,6 +236,11 @@ public TypeInfo visit(BlobType blobType) {
return TypeInfoFactory.binaryTypeInfo;
}

/**
 * Maps a Paimon vector type to a Hive list type.
 *
 * <p>Only the element type is converted; the vector's length is not used here.
 */
@Override
public TypeInfo visit(VectorType vectorType) {
    // Convert the element type with this same visitor, then wrap it in a list.
    TypeInfo elementTypeInfo = vectorType.getElementType().accept(this);
    return TypeInfoFactory.getListTypeInfo(elementTypeInfo);
}

@Override
protected TypeInfo defaultMethod(org.apache.paimon.types.DataType dataType) {
throw new UnsupportedOperationException("Unsupported type: " + dataType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.TimeType;
import org.apache.paimon.types.VarCharType;
import org.apache.paimon.types.VectorType;

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
Expand Down Expand Up @@ -81,6 +82,9 @@ public static ObjectInspector create(DataType logicalType) {
case ARRAY:
ArrayType arrayType = (ArrayType) logicalType;
return new PaimonListObjectInspector(arrayType.getElementType());
case VECTOR:
VectorType vectorType = (VectorType) logicalType;
return new PaimonListObjectInspector(vectorType.getElementType());
case MAP:
MapType mapType = (MapType) logicalType;
return new PaimonMapObjectInspector(mapType.getKeyType(), mapType.getValueType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DataTypeChecks;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.VectorType;
import org.apache.paimon.utils.InternalRowUtils;

import org.apache.spark.sql.catalyst.util.ArrayData;
Expand Down Expand Up @@ -172,7 +173,13 @@ public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numF

/**
 * Returns the value at {@code ordinal} as Spark {@link ArrayData}.
 *
 * <p>The field may be declared either as a Paimon array or as a Paimon vector; both are
 * converted to Spark array data. The stale unconditional {@code return} that used to precede
 * the type dispatch has been removed, since it made both branches unreachable.
 *
 * @param ordinal position of the field in the row
 * @return the converted array data
 * @throws UnsupportedOperationException if the field is neither an array nor a vector
 */
@Override
public ArrayData getArray(int ordinal) {
    DataType type = rowType.getTypeAt(ordinal);
    if (type instanceof ArrayType) {
        return fromPaimon(row.getArray(ordinal), (ArrayType) type);
    } else if (type instanceof VectorType) {
        return DataConverter.fromPaimon(row.getVector(ordinal), (VectorType) type);
    }
    throw new UnsupportedOperationException("Not an array type: " + type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.paimon.data.InternalArray;
import org.apache.paimon.data.InternalMap;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.InternalVector;
import org.apache.paimon.data.Timestamp;
import org.apache.paimon.spark.data.SparkArrayData;
import org.apache.paimon.spark.data.SparkInternalRow;
Expand All @@ -32,6 +33,7 @@
import org.apache.paimon.types.MapType;
import org.apache.paimon.types.MultisetType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.VectorType;

import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
import org.apache.spark.sql.catalyst.util.ArrayData;
Expand All @@ -58,6 +60,8 @@ public static Object fromPaimon(Object o, DataType type) {
return fromPaimon((org.apache.paimon.data.Decimal) o);
case ARRAY:
return fromPaimon((InternalArray) o, (ArrayType) type);
case VECTOR:
return fromPaimon((InternalVector) o, (VectorType) type);
case MAP:
case MULTISET:
return fromPaimon((InternalMap) o, type);
Expand Down Expand Up @@ -93,6 +97,16 @@ public static ArrayData fromPaimon(InternalArray array, ArrayType arrayType) {
return fromPaimonArrayElementType(array, arrayType.getElementType());
}

/**
 * Converts a Paimon vector to Spark {@link ArrayData}.
 *
 * @param vector the vector value to convert
 * @param vectorType the declared type, providing the expected length and the element type
 * @return Spark array data backed by the vector
 * @throws IllegalArgumentException if the vector's size differs from the declared length
 */
public static ArrayData fromPaimon(InternalVector vector, VectorType vectorType) {
    int actualSize = vector.size();
    int declaredLength = vectorType.getLength();
    if (actualSize == declaredLength) {
        return fromPaimonArrayElementType(vector, vectorType.getElementType());
    }
    String message =
            String.format(
                    "Vector length mismatch. Expected %d but was %d.",
                    declaredLength, actualSize);
    throw new IllegalArgumentException(message);
}

/**
 * Creates a Spark array wrapper for the given element type and points it at {@code array}.
 */
private static ArrayData fromPaimonArrayElementType(InternalArray array, DataType elementType) {
    ArrayData wrapped = SparkArrayData.create(elementType).replace(array);
    return wrapped;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
import org.apache.paimon.types.BlobType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.VectorType;
import org.apache.paimon.utils.ExceptionUtils;
import org.apache.paimon.utils.Preconditions;

import org.apache.spark.sql.PaimonSparkSession$;
import org.apache.spark.sql.SparkSession;
Expand Down Expand Up @@ -73,6 +75,7 @@
import org.apache.spark.sql.execution.datasources.DataSource;
import org.apache.spark.sql.execution.datasources.FileFormat;
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
Expand Down Expand Up @@ -563,6 +566,7 @@ private Schema toInitialSchema(
List<String> blobFields = CoreOptions.blobField(properties);
Set<String> blobDescriptorFields = new CoreOptions(properties).blobDescriptorField();
List<String> blobViewFields = CoreOptions.blobViewField(properties);
Set<String> vectorFields = CoreOptions.fromMap(properties).vectorField();
String provider = properties.get(TableCatalog.PROP_PROVIDER);
if (!usePaimon(provider)) {
if (isFormatTable(provider)) {
Expand Down Expand Up @@ -603,6 +607,17 @@ private Schema toInitialSchema(
field.dataType() instanceof org.apache.spark.sql.types.BinaryType,
"The type of blob field must be binary");
type = new BlobType();
} else if (vectorFields.contains(field.name())) {
Preconditions.checkArgument(
field.dataType() instanceof ArrayType,
"The type of blob field must be array");
ArrayType arrayType = (ArrayType) field.dataType();
String dimKey = String.format("field.%s.vector-dim", field.name());
type =
new VectorType(
arrayType.containsNull(),
Integer.parseInt(properties.get(dimKey)),
toPaimonType(arrayType.elementType()));
} else {
type = toPaimonType(field.dataType()).copy(field.nullable());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,20 @@ public InternalArray getArray(int pos) {

/**
 * Returns the vector stored at {@code pos}, adapted from the underlying Spark array data.
 *
 * <p>The stale unconditional {@code throw} that used to open this method has been removed; it
 * made the implementation below unreachable.
 *
 * @param pos field position as seen by the caller
 * @return the adapted vector, or {@code null} when the field has no actual position
 *     ({@code -1}) or the underlying value is null
 * @throws UnsupportedOperationException if the field's Spark type is not an array type
 */
@Override
public InternalVector getVector(int pos) {
    int actualPos = getActualFieldPosition(pos);
    if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
        return null;
    }
    // NOTE(review): the schema is indexed with pos while the row is read at actualPos —
    // confirm this matches the other getters of this class.
    DataType dataType = tableSchema.fields()[pos].dataType();
    return toSparkInternalVector(dataType, internalRow.getArray(actualPos));
}

/**
 * Wraps Spark array data as a Paimon vector.
 *
 * @param dataType Spark type of the value; must be an array type
 * @param arrayData the Spark array data to wrap
 * @throws UnsupportedOperationException if {@code dataType} is not an array type
 */
private static InternalVector toSparkInternalVector(DataType dataType, ArrayData arrayData) {
    if (dataType instanceof ArrayType) {
        DataType sparkElementType = ((ArrayType) dataType).elementType();
        return new SparkInternalVector(arrayData, sparkElementType);
    }
    throw new UnsupportedOperationException("Not a vector type: " + dataType);
}

@Override
Expand Down Expand Up @@ -435,7 +448,7 @@ public InternalArray getArray(int pos) {

/**
 * Returns the nested array at {@code pos} adapted as a vector.
 *
 * <p>The stale unconditional {@code throw} that used to precede the return has been removed;
 * it made the conversion unreachable.
 *
 * @param pos index of the element inside this array
 * @throws UnsupportedOperationException if this array's element type is not an array type
 */
@Override
public InternalVector getVector(int pos) {
    return toSparkInternalVector(elementType, arrayData.getArray(pos));
}

@Override
Expand All @@ -452,6 +465,13 @@ public InternalRow getRow(int pos, int numFields) {
}
}

/**
 * Adapts Spark array data to Paimon's {@link InternalVector}.
 *
 * <p>All behavior is inherited from {@link SparkInternalArray}; this subclass only adds the
 * {@link InternalVector} interface.
 */
public static class SparkInternalVector extends SparkInternalArray implements InternalVector {
    /**
     * @param arrayData the Spark array data to expose as a vector
     * @param elementType Spark type of the elements, passed through to the parent adapter
     */
    public SparkInternalVector(ArrayData arrayData, DataType elementType) {
        super(arrayData, elementType);
    }
}

/** adapt to spark internal map. */
public static class SparkInternalMap implements InternalMap {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.paimon.catalog.CatalogContext;
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.BinaryVector;
import org.apache.paimon.data.Blob;
import org.apache.paimon.data.Decimal;
import org.apache.paimon.data.InternalArray;
Expand All @@ -35,6 +36,7 @@
import org.apache.paimon.types.MapType;
import org.apache.paimon.types.RowKind;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.VectorType;
import org.apache.paimon.utils.DateTimeUtils;
import org.apache.paimon.utils.UriReaderFactory;

Expand All @@ -48,6 +50,7 @@
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -168,7 +171,10 @@ public InternalArray getArray(int i) {

/**
 * Returns the vector stored at {@code pos}, converted to Paimon's representation.
 *
 * <p>The stale unconditional {@code throw} that used to open this method has been removed; it
 * made the implementation below unreachable.
 *
 * @param pos field position in the wrapped row
 * @return the converted vector, or {@code null} if the field is null
 */
@Override
public InternalVector getVector(int pos) {
    if (row.isNullAt(pos)) {
        return null;
    }
    return toPaimonVector((VectorType) type.getTypeAt(pos), row.get(pos));
}

@Override
Expand Down Expand Up @@ -426,4 +432,92 @@ public double[] toDoubleArray() {
return res;
}
}

/**
 * Converts a vector value taken from a Spark row into a Paimon {@link InternalVector}.
 *
 * <p>Accepted inputs are primitive arrays ({@code boolean[]}, {@code byte[]}, {@code short[]},
 * {@code int[]}, {@code long[]}, {@code float[]}, {@code double[]}), Scala sequences, object
 * arrays and {@link List}s. Sequences, object arrays and lists are delegated to the
 * list-based overload.
 *
 * @param vectorType declared vector type, providing the expected length
 * @param vector the raw value; may be {@code null}
 * @return the converted vector, or {@code null} if {@code vector} is {@code null}
 * @throws IllegalArgumentException if a primitive array's length differs from the declared
 *     vector length
 * @throws UnsupportedOperationException if the value is of an unsupported class
 */
private static InternalVector toPaimonVector(VectorType vectorType, Object vector) {
    if (vector == null) {
        return null;
    }

    // Primitive arrays are converted directly below and never reach the list-based overload,
    // so the declared length has to be enforced here to keep both paths consistent.
    Class<?> componentType = vector.getClass().getComponentType();
    if (componentType != null && componentType.isPrimitive()) {
        int actualLength = java.lang.reflect.Array.getLength(vector);
        if (actualLength != vectorType.getLength()) {
            throw new IllegalArgumentException(
                    String.format(
                            "Vector length mismatch. Expected %d but was %d.",
                            vectorType.getLength(), actualLength));
        }
    }

    if (vector instanceof boolean[]) {
        return BinaryVector.fromPrimitiveArray((boolean[]) vector);
    } else if (vector instanceof byte[]) {
        return BinaryVector.fromPrimitiveArray((byte[]) vector);
    } else if (vector instanceof short[]) {
        return BinaryVector.fromPrimitiveArray((short[]) vector);
    } else if (vector instanceof int[]) {
        return BinaryVector.fromPrimitiveArray((int[]) vector);
    } else if (vector instanceof long[]) {
        return BinaryVector.fromPrimitiveArray((long[]) vector);
    } else if (vector instanceof float[]) {
        return BinaryVector.fromPrimitiveArray((float[]) vector);
    } else if (vector instanceof double[]) {
        return BinaryVector.fromPrimitiveArray((double[]) vector);
    }

    if (vector instanceof scala.collection.Seq) {
        vector = JavaConverters.seqAsJavaList((scala.collection.Seq<Object>) vector);
    } else if (vector instanceof Object[]) {
        // Only object arrays can be viewed as a list; a remaining primitive array (char[])
        // falls through to the "unsupported" error instead of a ClassCastException.
        vector = Arrays.asList((Object[]) vector);
    }
    if (!(vector instanceof List)) {
        throw new UnsupportedOperationException(
                "Unsupported vector object: " + vector.getClass().getName());
    }
    return toPaimonVector(vectorType, (List<?>) vector);
}

/**
 * Builds a Paimon {@link InternalVector} from a list of boxed values.
 *
 * <p>The list is copied into a primitive array chosen by the vector's element type and then
 * handed to {@code BinaryVector.fromPrimitiveArray}. Boolean vectors expect {@link Boolean}
 * elements; all numeric vectors expect {@link Number} elements.
 *
 * @param vectorType declared vector type, providing the expected length and element type
 * @param list the values to copy
 * @return the converted vector
 * @throws IllegalArgumentException if the list size differs from the declared length
 * @throws UnsupportedOperationException if the element type is not one of the handled
 *     primitive types
 */
private static InternalVector toPaimonVector(VectorType vectorType, List<?> list) {
    final int dimension = vectorType.getLength();
    if (list.size() != dimension) {
        String message =
                String.format(
                        "Vector length mismatch. Expected %d but was %d.",
                        dimension, list.size());
        throw new IllegalArgumentException(message);
    }

    int index = 0;
    switch (vectorType.getElementType().getTypeRoot()) {
        case BOOLEAN:
            {
                boolean[] flags = new boolean[dimension];
                for (Object element : list) {
                    flags[index++] = (Boolean) element;
                }
                return BinaryVector.fromPrimitiveArray(flags);
            }
        case TINYINT:
            {
                byte[] bytes = new byte[dimension];
                for (Object element : list) {
                    bytes[index++] = ((Number) element).byteValue();
                }
                return BinaryVector.fromPrimitiveArray(bytes);
            }
        case SMALLINT:
            {
                short[] shorts = new short[dimension];
                for (Object element : list) {
                    shorts[index++] = ((Number) element).shortValue();
                }
                return BinaryVector.fromPrimitiveArray(shorts);
            }
        case INTEGER:
            {
                int[] ints = new int[dimension];
                for (Object element : list) {
                    ints[index++] = ((Number) element).intValue();
                }
                return BinaryVector.fromPrimitiveArray(ints);
            }
        case BIGINT:
            {
                long[] longs = new long[dimension];
                for (Object element : list) {
                    longs[index++] = ((Number) element).longValue();
                }
                return BinaryVector.fromPrimitiveArray(longs);
            }
        case FLOAT:
            {
                float[] floats = new float[dimension];
                for (Object element : list) {
                    floats[index++] = ((Number) element).floatValue();
                }
                return BinaryVector.fromPrimitiveArray(floats);
            }
        case DOUBLE:
            {
                double[] doubles = new double[dimension];
                for (Object element : list) {
                    doubles[index++] = ((Number) element).doubleValue();
                }
                return BinaryVector.fromPrimitiveArray(doubles);
            }
        default:
            throw new UnsupportedOperationException(
                    "Unsupported element type for vector " + vectorType.getElementType());
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.paimon.types.VarBinaryType;
import org.apache.paimon.types.VarCharType;
import org.apache.paimon.types.VariantType;
import org.apache.paimon.types.VectorType;

import org.apache.spark.sql.paimon.shims.SparkShimLoader;
import org.apache.spark.sql.types.DataType;
Expand Down Expand Up @@ -127,8 +128,16 @@ private static org.apache.paimon.types.DataType prunePaimonType(
} else if (sparkDataType instanceof org.apache.spark.sql.types.ArrayType) {
org.apache.spark.sql.types.ArrayType s =
(org.apache.spark.sql.types.ArrayType) sparkDataType;
ArrayType r = (ArrayType) paimonDataType;
return r.newElementType(prunePaimonType(s.elementType(), r.getElementType()));
if (paimonDataType instanceof VectorType) {
VectorType v = (VectorType) paimonDataType;
return new VectorType(
v.isNullable(),
v.getLength(),
prunePaimonType(s.elementType(), v.getElementType()));
} else {
ArrayType r = (ArrayType) paimonDataType;
return r.newElementType(prunePaimonType(s.elementType(), r.getElementType()));
}
} else {
return paimonDataType;
}
Expand Down Expand Up @@ -242,6 +251,11 @@ public DataType visit(ArrayType arrayType) {
return DataTypes.createArrayType(elementType.accept(this), elementType.isNullable());
}

/**
 * Maps a Paimon vector type to a Spark array type.
 *
 * <p>The element type is converted with this same visitor; the vector's length is not used
 * here.
 */
@Override
public DataType visit(VectorType vectorType) {
    DataType sparkElementType = vectorType.getElementType().accept(this);
    // NOTE(review): containsNull is taken from the vector's own nullability, whereas
    // visit(ArrayType) passes the element type's nullability — confirm this is intended.
    boolean containsNull = vectorType.isNullable();
    return DataTypes.createArrayType(sparkElementType, containsNull);
}

@Override
public DataType visit(MultisetType multisetType) {
return DataTypes.createMapType(
Expand Down
Loading
Loading