From b4fbb07ebcbc77b5b93a0ef6ba5450b151736327 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 12 Nov 2025 17:03:44 -0500 Subject: [PATCH 01/11] feat: add structured UDT literal support with dual encoding Support both opaque (google.protobuf.Any) and structured (Literal.Struct) encodings for user-defined type literals per Substrait spec. - Split UserDefinedLiteral into UserDefinedAny and UserDefinedStruct - Move type parameters to interface level for parameterized types - Add ExtensionCollector.getExtensionCollection() method - Full Calcite integration with REINTERPRET pattern for Any-based UDTs - Add Scala visitor methods and comprehensive documentation - Comprehensive test coverage including roundtrip tests --- .../expression/AbstractExpressionVisitor.java | 10 + .../io/substrait/expression/Expression.java | 87 ++++++- .../expression/ExpressionCreator.java | 80 +++++- .../expression/ExpressionVisitor.java | 4 +- .../proto/ExpressionProtoConverter.java | 41 ++- .../proto/ProtoExpressionConverter.java | 30 ++- .../extension/ExtensionCollector.java | 71 ++++++ .../substrait/extension/SimpleExtension.java | 28 +++ .../ExpressionCopyOnWriteVisitor.java | 8 +- .../src/main/java/io/substrait/type/Type.java | 17 ++ .../type/proto/BaseProtoConverter.java | 2 +- .../substrait/type/proto/BaseProtoTypes.java | 3 + .../proto/ParameterizedProtoConverter.java | 7 + .../type/proto/ProtoTypeConverter.java | 8 +- .../proto/TypeExpressionProtoVisitor.java | 7 + .../type/proto/TypeProtoConverter.java | 11 + .../ExtensionCollectorGetCollectionTest.java | 173 +++++++++++++ .../type/proto/LiteralRoundtripTest.java | 117 +++++++++ .../examples/util/ExpressionStringify.java | 12 +- .../examples/util/SubstraitStringify.java | 2 +- .../io/substrait/isthmus/TypeConverter.java | 4 +- .../isthmus/expression/CallConverters.java | 58 +++-- .../expression/ExpressionRexConverter.java | 43 +++- .../isthmus/expression/LiteralConverter.java | 21 +- .../type/SubstraitUserDefinedType.java | 233 
++++++++++++++++++ .../substrait/isthmus/CustomFunctionTest.java | 229 ++++++++++++++++- .../substrait/debug/ExpressionToString.scala | 6 +- .../spark/DefaultExpressionVisitor.scala | 10 +- 28 files changed, 1245 insertions(+), 77 deletions(-) create mode 100644 core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java create mode 100644 isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..d190542f8 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.UserDefinedAny expr, C context) throws E { + return visitFallback(expr, context); + } + + @Override + public O visit(Expression.UserDefinedStruct expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(Expression.Switch expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..1b0a8362d 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -662,21 +662,96 @@ public R accept( } } + /** + * Base interface for user-defined literals. + * + *

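<p>User-defined literals can be encoded in one of two ways as per the Substrait spec: + * + *
<ul> + * <li>As an opaque {@code google.protobuf.Any} value; see {@link UserDefinedAny}. + * <li>As a structured {@code Literal.Struct} list of fields; see {@link UserDefinedStruct}. + * </ul>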
+ * + * @see UserDefinedAny + * @see UserDefinedStruct + */ + interface UserDefinedLiteral extends Literal { + String urn(); + + String name(); + + List typeParameters(); + } + + /** + * User-defined literal with value encoded as {@code google.protobuf.Any}. + * + *

This encoding allows for arbitrary binary data to be stored in the literal value. + */ @Value.Immutable - abstract class UserDefinedLiteral implements Literal { - public abstract ByteString value(); + abstract class UserDefinedAny implements UserDefinedLiteral { + @Override + public abstract String urn(); + + @Override + public abstract String name(); + + @Override + public abstract List typeParameters(); + + public abstract com.google.protobuf.Any value(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); + } + + public static ImmutableExpression.UserDefinedAny.Builder builder() { + return ImmutableExpression.UserDefinedAny.builder(); + } + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + + /** + * User-defined literal with value encoded as {@code Literal.Struct}. + * + *

This encoding uses a structured list of fields to represent the literal value. + */ + @Value.Immutable + abstract class UserDefinedStruct implements UserDefinedLiteral { + @Override public abstract String urn(); + @Override public abstract String name(); @Override - public Type getType() { - return Type.withNullability(nullable()).userDefined(urn(), name()); + public abstract List typeParameters(); + + public abstract List fields(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); } - public static ImmutableExpression.UserDefinedLiteral.Builder builder() { - return ImmutableExpression.UserDefinedLiteral.builder(); + public static ImmutableExpression.UserDefinedStruct.Builder builder() { + return ImmutableExpression.UserDefinedStruct.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..74de45732 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -286,13 +286,87 @@ public static Expression.StructLiteral struct( return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); } - public static Expression.UserDefinedLiteral userDefinedLiteral( + /** + * Create a UserDefinedAny with google.protobuf.Any representation. 
+ * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAny userDefinedLiteralAny( boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedLiteral.builder() + return Expression.UserDefinedAny.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .value(value) + .build(); + } + + /** + * Create a UserDefinedAny with google.protobuf.Any representation and type parameters. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAny userDefinedLiteralAny( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + Any value) { + return Expression.UserDefinedAny.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .value(value) + .build(); + } + + /** + * Create a UserDefinedStruct with Struct representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStruct userDefinedLiteralStruct( + boolean nullable, String urn, String name, java.util.List fields) { + return Expression.UserDefinedStruct.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllFields(fields) + .build(); + } + + /** + * Create a UserDefinedStruct with Struct representation and type parameters. 
+ * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStruct userDefinedLiteralStruct( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + java.util.List fields) { + return Expression.UserDefinedStruct.builder() .nullable(nullable) .urn(urn) .name(name) - .value(value.toByteString()) + .addAllTypeParameters(typeParameters) + .addAllFields(fields) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index d64cab48c..7cec9b953 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -62,7 +62,9 @@ public interface ExpressionVisitor { - try { - bldr.setNullable(expr.nullable()) - .setUserDefined( - Expression.Literal.UserDefined.newBuilder() - .setTypeReference(typeReference) - .setValue(Any.parseFrom(expr.value()))) - .build(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException(e); + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setValue(expr.value()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); + }); + } + + @Override + public Expression visit( + io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) { + int typeReference = + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); + return lit( + bldr -> { + Expression.Literal.Struct.Builder structBuilder = 
Expression.Literal.Struct.newBuilder(); + for (io.substrait.expression.Expression.Literal field : expr.fields()) { + structBuilder.addFields(toLiteral(field)); } + + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setStruct(structBuilder.build()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); }); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..847fcae55 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); + SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); + String urn = type.urn(); + String name = type.name(); + + switch (userDefinedLiteral.getValCase()) { + case VALUE: + return ExpressionCreator.userDefinedLiteralAny( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getValue()); + case STRUCT: + return ExpressionCreator.userDefinedLiteralStruct( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case VAL_NOT_SET: + throw new IllegalStateException( + "UserDefined literal has no value (neither 'value' nor 'struct' is set)"); + default: + throw new 
IllegalStateException( + "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase()); + } } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 7ad07a6b1..1ac3b2f1b 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -63,6 +63,77 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { return counter; } + /** + * Returns an ExtensionCollection containing only the types and functions that have been tracked + * by this collector. This provides a minimal collection with exactly what was used during + * serialization. + * + *

<p>This collection contains: + * + *
<ul> + *
  <li>Only the types that were referenced via {@link #getTypeReference} + *
  <li>Only the functions that were referenced via {@link #getFunctionReference} + *
  <li>URI/URN mappings for only the used extension URNs + *
</ul> + * + *

Types from the catalog are resolved, while custom UserDefined types (not in the catalog) are + * created via {@link SimpleExtension.Type#of(String, String)}. + * + * @return an ExtensionCollection with only the used types, functions, and URI/URN mappings + */ + public SimpleExtension.ExtensionCollection getExtensionCollection() { + java.util.List types = new ArrayList<>(); + java.util.List scalarFunctions = new ArrayList<>(); + java.util.List aggregateFunctions = new ArrayList<>(); + java.util.List windowFunctions = new ArrayList<>(); + + java.util.Set usedUrns = new java.util.HashSet<>(); + + for (Map.Entry entry : typeMap.forwardEntrySet()) { + SimpleExtension.TypeAnchor anchor = entry.getValue(); + usedUrns.add(anchor.urn()); + if (extensionCollection.hasType(anchor)) { + types.add(extensionCollection.getType(anchor)); + } else { + types.add(SimpleExtension.Type.of(anchor.urn(), anchor.key())); + } + } + + for (Map.Entry entry : funcMap.forwardEntrySet()) { + SimpleExtension.FunctionAnchor anchor = entry.getValue(); + usedUrns.add(anchor.urn()); + + if (extensionCollection.hasScalarFunction(anchor)) { + scalarFunctions.add(extensionCollection.getScalarFunction(anchor)); + } else if (extensionCollection.hasAggregateFunction(anchor)) { + aggregateFunctions.add(extensionCollection.getAggregateFunction(anchor)); + } else if (extensionCollection.hasWindowFunction(anchor)) { + windowFunctions.add(extensionCollection.getWindowFunction(anchor)); + } else { + throw new IllegalArgumentException( + String.format( + "Function %s::%s was tracked but not found in catalog as scalar, aggregate, or window function", + anchor.urn(), anchor.key())); + } + } + + BidiMap uriUrnMap = new BidiMap<>(); + for (String urn : usedUrns) { + String uri = extensionCollection.getUriFromUrn(urn); + if (uri != null) { + uriUrnMap.put(uri, urn); + } + } + + return SimpleExtension.ExtensionCollection.builder() + .addAllTypes(types) + .addAllScalarFunctions(scalarFunctions) + 
.addAllAggregateFunctions(aggregateFunctions) + .addAllWindowFunctions(windowFunctions) + .uriUrnMap(uriUrnMap) + .build(); + } + public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 39d7c45e0..02608cbf3 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -551,6 +551,18 @@ public abstract static class Type { public TypeAnchor getAnchor() { return anchorSupplier.get(); } + + /** + * Creates a minimal Type instance for custom UserDefined types that aren't loaded from YAML. + * This is useful for programmatically constructed types during protobuf deserialization. + * + * @param urn the extension URN (e.g., "extension:test:custom") + * @param name the type name (e.g., "MyCustomType") + * @return a Type instance with the specified urn and name + */ + public static Type of(String urn, String name) { + return ImmutableSimpleExtension.Type.builder().urn(urn).name(name).build(); + } } @JsonDeserialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) @@ -666,6 +678,10 @@ public Type getType(TypeAnchor anchor) { anchor.key(), anchor.urn())); } + public boolean hasType(TypeAnchor anchor) { + return typeLookup.get().containsKey(anchor); + } + public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { ScalarFunctionVariant variant = scalarFunctionsLookup.get().get(anchor); if (variant != null) { @@ -718,6 +734,18 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } + public boolean hasScalarFunction(FunctionAnchor anchor) { + return scalarFunctionsLookup.get().containsKey(anchor); + } + + public boolean hasAggregateFunction(FunctionAnchor anchor) { + return aggregateFunctionsLookup.get().containsKey(anchor); + } 
+ + public boolean hasWindowFunction(FunctionAnchor anchor) { + return windowFunctionsLookup.get().containsKey(anchor); + } + /** * Gets the URI for a given URN. This is for internal framework use during URI/URN migration. * diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..68395ac0d 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -203,9 +203,15 @@ public Optional visit(Expression.StructLiteral expr, EmptyVisitation return visitLiteral(expr); } + @Override + public Optional visit(Expression.UserDefinedAny expr, EmptyVisitationContext context) + throws E { + return visitLiteral(expr); + } + @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..7ef2d75a7 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -393,6 +393,23 @@ abstract class UserDefined implements Type { public abstract String name(); + /** + * Returns the type parameters for this user-defined type. + * + *

<p>Type parameters are used to represent parameterized/generic types, such as {@code + * List<T>} or {@code Map<K, V>}. Each parameter in the list represents a type argument + * that specializes the generic user-defined type. + * + *

For example, a user-defined type {@code MyList} parameterized by {@code i32} would have + * one type parameter containing the {@code i32} type definition. + * + * @return a list of type parameters, or an empty list if this type is not parameterized + */ + @Value.Default + public java.util.List typeParameters() { + return java.util.Collections.emptyList(); + } + public static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..67d7bc9b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref); + return typeContainer(expr).userDefined(ref, expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..1009fe52a 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,6 +131,9 @@ public final T struct(T... 
types) { public abstract T userDefined(int ref); + public abstract T userDefined( + int ref, java.util.List typeParameters); + protected abstract T wrap(Object o); protected abstract I i(int integerValue); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..137c1fba3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) { "User defined types are not supported in Parameterized Types for now"); } + @Override + public ParameterizedType userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Parameterized Types for now"); + } + @Override protected ParameterizedType wrap(final Object o) { ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..ee77e1445 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,13 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); + boolean nullable = isNullable(userDefined.getNullability()); + return io.substrait.type.Type.UserDefined.builder() + .nullable(nullable) + .urn(t.urn()) + .name(t.name()) + .typeParameters(userDefined.getTypeParametersList()) + .build(); } case USER_DEFINED_TYPE_REFERENCE: throw new 
UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..a3412a9e3 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -355,6 +355,13 @@ public DerivationExpression userDefined(int ref) { "User defined types are not supported in Derivation Expressions for now"); } + @Override + public DerivationExpression userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override protected DerivationExpression wrap(final Object o) { DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..7cb98263f 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -133,6 +133,17 @@ public Type userDefined(int ref) { Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } + @Override + public Type userDefined( + int ref, java.util.List typeParameters) { + return wrap( + Type.UserDefined.newBuilder() + .setTypeReference(ref) + .setNullability(nullability) + .addAllTypeParameters(typeParameters) + .build()); + } + @Override protected Type wrap(final Object o) { Type.Builder bldr = Type.newBuilder(); diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java new file mode 100644 index 000000000..eca391891 --- /dev/null 
+++ b/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java @@ -0,0 +1,173 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class ExtensionCollectorGetCollectionTest { + + @Test + public void getExtensionCollection_containsOnlyTrackedTypes() { + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("https://example.com/catalog", "extension:example:catalog"); + + SimpleExtension.Type catalogType1 = + SimpleExtension.Type.of("extension:example:catalog", "type1"); + SimpleExtension.Type catalogType2 = + SimpleExtension.Type.of("extension:example:catalog", "type2"); + SimpleExtension.Type catalogType3 = + SimpleExtension.Type.of("extension:example:catalog", "type3"); + + SimpleExtension.ExtensionCollection catalog = + SimpleExtension.ExtensionCollection.builder() + .addTypes(catalogType1, catalogType2, catalogType3) + .uriUrnMap(uriUrnMap) + .build(); + + ExtensionCollector collector = new ExtensionCollector(catalog); + + collector.getTypeReference(catalogType1.getAnchor()); + collector.getTypeReference(catalogType2.getAnchor()); + + SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); + + assertEquals(2, result.types().size()); + assertEquals("type1", result.types().get(0).name()); + assertEquals("type2", result.types().get(1).name()); + } + + @Test + public void getExtensionCollection_containsOnlyTrackedFunctions() { + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("https://example.com/catalog", "extension:example:catalog"); + + SimpleExtension.ScalarFunctionVariant func1 = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:example:catalog") + .name("func1") + .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + 
SimpleExtension.ScalarFunctionVariant func2 = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:example:catalog") + .name("func2") + .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + SimpleExtension.ScalarFunctionVariant func3 = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:example:catalog") + .name("func3") + .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + SimpleExtension.ExtensionCollection catalog = + SimpleExtension.ExtensionCollection.builder() + .addScalarFunctions(func1, func2, func3) + .uriUrnMap(uriUrnMap) + .build(); + + ExtensionCollector collector = new ExtensionCollector(catalog); + + collector.getFunctionReference(func1); + collector.getFunctionReference(func3); + + SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); + + assertEquals(2, result.scalarFunctions().size()); + assertEquals("func1", result.scalarFunctions().get(0).name()); + assertEquals("func3", result.scalarFunctions().get(1).name()); + } + + @Test + public void getExtensionCollection_includesCustomTypes() { + SimpleExtension.ExtensionCollection emptyCatalog = + SimpleExtension.ExtensionCollection.builder().build(); + + ExtensionCollector collector = new ExtensionCollector(emptyCatalog); + + SimpleExtension.TypeAnchor customType = + SimpleExtension.TypeAnchor.of("extension:test:custom", "MyCustomType"); + + collector.getTypeReference(customType); + + SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); + + assertEquals(1, result.types().size()); + assertEquals("MyCustomType", result.types().get(0).name()); + assertEquals("extension:test:custom", result.types().get(0).urn()); + } + + @Test + public void getExtensionCollection_includesOnlyUsedUriUrnMappings() { + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put("https://example.com/urn1", "extension:example:urn1"); + 
uriUrnMap.put("https://example.com/urn2", "extension:example:urn2"); + uriUrnMap.put("https://example.com/urn3", "extension:example:urn3"); + + SimpleExtension.Type type1 = SimpleExtension.Type.of("extension:example:urn1", "type1"); + SimpleExtension.Type type2 = SimpleExtension.Type.of("extension:example:urn2", "type2"); + SimpleExtension.Type type3 = SimpleExtension.Type.of("extension:example:urn3", "type3"); + + SimpleExtension.ExtensionCollection catalog = + SimpleExtension.ExtensionCollection.builder() + .addTypes(type1, type2, type3) + .uriUrnMap(uriUrnMap) + .build(); + + ExtensionCollector collector = new ExtensionCollector(catalog); + + collector.getTypeReference(type1.getAnchor()); + collector.getTypeReference(type3.getAnchor()); + + SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); + + assertEquals(2, result.uriUrnMap().forwardEntrySet().size()); + assertNotNull(result.getUriFromUrn("extension:example:urn1")); + assertNotNull(result.getUriFromUrn("extension:example:urn3")); + assertEquals("https://example.com/urn1", result.getUriFromUrn("extension:example:urn1")); + assertEquals("https://example.com/urn3", result.getUriFromUrn("extension:example:urn3")); + } + + @Test + public void getExtensionCollection_emptyWhenNothingTracked() { + SimpleExtension.ExtensionCollection catalog = + SimpleExtension.ExtensionCollection.builder().build(); + + ExtensionCollector collector = new ExtensionCollector(catalog); + + SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); + + assertTrue(result.types().isEmpty()); + assertTrue(result.scalarFunctions().isEmpty()); + assertTrue(result.aggregateFunctions().isEmpty()); + assertTrue(result.windowFunctions().isEmpty()); + } + + @Test + public void getExtensionCollection_throwsWhenFunctionNotInCatalog() { + SimpleExtension.ExtensionCollection emptyCatalog = + SimpleExtension.ExtensionCollection.builder().build(); + + ExtensionCollector collector = new 
ExtensionCollector(emptyCatalog); + + SimpleExtension.ScalarFunctionVariant func = + ImmutableSimpleExtension.ScalarFunctionVariant.builder() + .urn("extension:missing:catalog") + .name("missing_func") + .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) + .build(); + + collector.getFunctionReference(func); + + IllegalArgumentException exception = + org.junit.jupiter.api.Assertions.assertThrows( + IllegalArgumentException.class, () -> collector.getExtensionCollection()); + + assertTrue(exception.getMessage().contains("extension:missing:catalog::missing_func")); + assertTrue(exception.getMessage().contains("not found in catalog")); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index ccac93bcb..1ce55163a 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -3,16 +3,23 @@ import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.Any; import io.substrait.TestBase; +import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.SimpleExtension; import io.substrait.util.EmptyVisitationContext; import java.math.BigDecimal; import org.junit.jupiter.api.Test; public class LiteralRoundtripTest extends TestBase { + // Load custom extensions for UserDefined literal tests + private static final SimpleExtension.ExtensionCollection testExtensions = + SimpleExtension.load(java.util.Arrays.asList("/extensions/custom_extensions.yaml")); + @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = @@ -22,4 +29,114 @@ void decimal() 
{ new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } + + @Test + void userDefinedLiteralWithAnyRepresentation() { + io.substrait.proto.Expression.Literal innerLiteral = + io.substrait.proto.Expression.Literal.newBuilder().setI32(42).build(); + Any anyValue = Any.pack(innerLiteral); + + String urn = "extension:test:custom_extensions"; + String typeName = "customType1"; + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, anyValue); + + ExpressionProtoConverter exprProtoConv = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter protoExprConv = + new ProtoExpressionConverter( + functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); + assertEquals(val, protoExprConv.from(exprProtoConv.toProto(val))); + } + + @Test + void userDefinedLiteralWithStructRepresentation() { + String urn = "extension:test:custom_extensions"; + String typeName = "customType2"; + + java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.string(false, "test")); + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, fields); + + ExpressionProtoConverter exprProtoConv = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter protoExprConv = + new ProtoExpressionConverter( + functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); + assertEquals(val, protoExprConv.from(exprProtoConv.toProto(val))); + } + + @Test + void userDefinedLiteralWithAnyRepresentationAndTypeParameters() { + io.substrait.proto.Expression.Literal innerLiteral = + io.substrait.proto.Expression.Literal.newBuilder().setI32(42).build(); + Any anyValue = Any.pack(innerLiteral); + + String urn = "extension:test:custom_extensions"; + String typeName = "customType1"; + + 
java.util.List typeParams = + java.util.Arrays.asList( + io.substrait.proto.Type.Parameter.newBuilder() + .setDataType( + io.substrait.proto.Type.newBuilder() + .setI32( + io.substrait.proto.Type.I32 + .newBuilder() + .setNullability( + io.substrait.proto.Type.Nullability.NULLABILITY_REQUIRED))) + .build()); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, typeParams, anyValue); + + ExpressionProtoConverter exprProtoConv = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter protoExprConv = + new ProtoExpressionConverter( + functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); + + Expression.UserDefinedLiteral roundtripped = + (Expression.UserDefinedLiteral) protoExprConv.from(exprProtoConv.toProto(val)); + + assertEquals(val, roundtripped); + } + + @Test + void userDefinedLiteralWithStructRepresentationAndTypeParameters() { + String urn = "extension:test:custom_extensions"; + String typeName = "customType2"; + + java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.string(false, "test")); + + java.util.List typeParams = + java.util.Arrays.asList( + io.substrait.proto.Type.Parameter.newBuilder() + .setDataType( + io.substrait.proto.Type.newBuilder() + .setString( + io.substrait.proto.Type.String.newBuilder() + .setNullability( + io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE))) + .build()); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, typeParams, fields); + + ExpressionProtoConverter exprProtoConv = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter protoExprConv = + new ProtoExpressionConverter( + functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); + + Expression.UserDefinedLiteral roundtripped = + (Expression.UserDefinedLiteral) 
protoExprConv.from(exprProtoConv.toProto(val)); + + assertEquals(val, roundtripped); + } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..cee84d13d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -37,7 +37,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.Expression.UserDefinedAny; +import io.substrait.expression.Expression.UserDefinedStruct; import io.substrait.expression.Expression.VarCharLiteral; import io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -188,9 +189,14 @@ public String visit(StructLiteral expr, EmptyVisitationContext context) throws R } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(UserDefinedAny expr, EmptyVisitationContext context) throws RuntimeException { + return ""; + } + + @Override + public String visit(UserDefinedStruct expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index f3b34f6c2..de67a3ee8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -91,7 +91,7 @@ public static List 
explain(io.substrait.plan.Plan plan) { /** * Explains the Sustrait relation * - * @param plan Subsrait relation + * @param rel Substrait relation * @return List of strings; typically these would then be logged or sent to stdout */ public static List explain(io.substrait.relation.Rel rel) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 932b8f6d8..5cdf8088f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -363,8 +363,8 @@ public RelDataType visit(Type.UserDefined expr) throws RuntimeException { if (type != null) { return type; } - throw new UnsupportedOperationException( - String.format("Unable to map user-defined type: %s", expr)); + return io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( + expr); } private boolean n(NullableType type) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..f45ca86d3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -5,7 +5,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.TypeConverter; -import io.substrait.type.Type; +import io.substrait.isthmus.type.SubstraitUserDefinedType; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -41,18 +41,26 @@ public class CallConverters { }; /** - * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link - * Expression.UserDefinedLiteral}s within Calcite. + * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent {@link + * Expression.UserDefinedAny} literals within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedLiteral#value()} - * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link - * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, UserDefinedAny literals are serialized to binary + * and stored as {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link + * org.apache.calcite.rex.RexLiteral}, then re-interpreted to have a custom {@link + * SubstraitUserDefinedType.SubstraitUserDefinedAnyType} that preserves all metadata including + * type parameters. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, - * SubstraitRelNodeConverter.Context)} for this conversion. + *

Note: {@link Expression.UserDefinedStruct} literals are NOT handled via REINTERPRET. + * Instead, they are represented as Calcite ROW literals with {@link + * SubstraitUserDefinedType.SubstraitUserDefinedStructType} and converted via {@link + * LiteralConverter}. * - *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedLiteral} that was stored. + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAny, + * SubstraitRelNodeConverter.Context)} for the UserDefinedAny conversion. + * + *

When converting from Calcite back to Substrait, this call converter deserializes the binary + * value and reconstructs the UserDefinedAny literal with all metadata preserved (including type + * parameters). */ public static Function REINTERPRET = typeConverter -> @@ -61,20 +69,28 @@ public class CallConverters { return null; } Expression operand = visitor.apply(call.getOperands().get(0)); - Type type = typeConverter.toSubstrait(call.getType()); - // For now, we only support handling of SqlKind.REINTEPRETET for the case of stored - // user-defined literals if (operand instanceof Expression.FixedBinaryLiteral - && type instanceof Type.UserDefined) { + && call.getType() instanceof SubstraitUserDefinedType.SubstraitUserDefinedAnyType) { Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; - Type.UserDefined t = (Type.UserDefined) type; - - return Expression.UserDefinedLiteral.builder() - .urn(t.urn()) - .name(t.name()) - .value(literal.value()) - .build(); + SubstraitUserDefinedType.SubstraitUserDefinedAnyType customType = + (SubstraitUserDefinedType.SubstraitUserDefinedAnyType) call.getType(); + + try { + com.google.protobuf.Any anyValue = + com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); + + return Expression.UserDefinedAny.builder() + .urn(customType.getUrn()) + .name(customType.getName()) + .typeParameters(customType.getTypeParameters()) + .value(anyValue) + .nullable(customType.isNullable()) + .build(); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Failed to parse UserDefinedAny literal value", e); + } } return null; }; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..bf6aeceba 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,12 +109,47 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) - throws RuntimeException { + public RexNode visit(Expression.UserDefinedAny expr, Context context) throws RuntimeException { + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType customType = + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( + expr.getType()); + RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); - RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); - return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); + return rexBuilder.makeReinterpretCast(customType, binaryLiteral, rexBuilder.makeLiteral(false)); + } + + @Override + public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { + // Convert field types to Calcite types for the struct representation + java.util.List fieldTypes = + expr.fields().stream() + .map(field -> typeConverter.toCalcite(typeFactory, field.getType())) + .collect(java.util.stream.Collectors.toList()); + + // Generate dummy field names (f0, f1, f2, etc.) to satisfy Calcite's ROW type requirements. + // Substrait UserDefinedStruct doesn't have field names - just ordered field values. + // These synthetic names are discarded during conversion back to Substrait. 
+ java.util.List fieldNames = + java.util.stream.IntStream.range(0, expr.fields().size()) + .mapToObj(i -> "f" + i) + .collect(java.util.stream.Collectors.toList()); + + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType customType = + new io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType( + expr.urn(), + expr.name(), + expr.typeParameters(), + expr.nullable(), + fieldTypes, + fieldNames); + + java.util.List fieldLiterals = + expr.fields().stream() + .map(field -> (RexLiteral) field.accept(this, context)) + .collect(java.util.stream.Collectors.toList()); + + return rexBuilder.makeLiteral(fieldLiterals, customType, false); } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 02cb8a116..910b9682f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -27,8 +27,6 @@ import org.apache.calcite.util.TimestampString; public class LiteralConverter { - // TODO: Handle conversion of user-defined type literals - static final DateTimeFormatter CALCITE_LOCAL_DATE_FORMATTER = DateTimeFormatter.ISO_LOCAL_DATE; static final DateTimeFormatter CALCITE_LOCAL_TIME_FORMATTER = new DateTimeFormatterBuilder() @@ -195,6 +193,25 @@ public Expression.Literal convert(RexLiteral literal) { case ROW: { + // Check if this is a SubstraitUserDefinedStructType + if (literal.getType() + instanceof + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType) { + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType + udtType = + (io.substrait.isthmus.type.SubstraitUserDefinedType + .SubstraitUserDefinedStructType) + literal.getType(); + List literals = (List) literal.getValue(); + return ExpressionCreator.userDefinedLiteralStruct( + 
udtType.isNullable(), + udtType.getUrn(), + udtType.getName(), + udtType.getTypeParameters(), + literals.stream().map(this::convert).collect(Collectors.toList())); + } + + // Regular struct List literals = (List) literal.getValue(); return ExpressionCreator.struct( n, literals.stream().map(this::convert).collect(Collectors.toList())); diff --git a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java new file mode 100644 index 000000000..48fda7c71 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java @@ -0,0 +1,233 @@ +package io.substrait.isthmus.type; + +import io.substrait.type.Type; +import java.util.List; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.rel.type.RelDataTypeImpl; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Base class for custom Calcite {@link RelDataType} implementations representing Substrait + * user-defined types. + * + *

These custom types preserve all UDT metadata (URN, name, type parameters) during Calcite + * roundtrips, unlike the previous approach which flattened everything to binary with REINTERPRET. + * + *

Two concrete implementations exist: + * + *

    + *
  • {@link SubstraitUserDefinedAnyType} - For opaque binary UDT literals (wraps protobuf Any) + *
  • {@link SubstraitUserDefinedStructType} - For structured UDT literals with accessible fields + *
+ * + * @see SubstraitUserDefinedAnyType + * @see SubstraitUserDefinedStructType + * @see io.substrait.expression.Expression.UserDefinedAny + * @see io.substrait.expression.Expression.UserDefinedStruct + */ +public abstract class SubstraitUserDefinedType extends RelDataTypeImpl { + + private final String urn; + private final String name; + private final List typeParameters; + private final boolean nullable; + + protected SubstraitUserDefinedType( + String urn, + String name, + List typeParameters, + boolean nullable) { + this.urn = urn; + this.name = name; + this.typeParameters = + typeParameters != null ? typeParameters : java.util.Collections.emptyList(); + this.nullable = nullable; + computeDigest(); + } + + public String getUrn() { + return urn; + } + + public String getName() { + return name; + } + + public List getTypeParameters() { + return typeParameters; + } + + @Override + public boolean isNullable() { + return nullable; + } + + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.OTHER; + } + + /** Converts this Calcite type back to a Substrait {@link Type.UserDefined}. */ + public Type.UserDefined toSubstraitType() { + return Type.UserDefined.builder() + .urn(urn) + .name(name) + .typeParameters(typeParameters) + .nullable(nullable) + .build(); + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + sb.append(name); + if (!typeParameters.isEmpty()) { + sb.append("<"); + sb.append(String.join(", ", java.util.Collections.nCopies(typeParameters.size(), "_"))); + sb.append(">"); + } + } + + /** + * Custom Calcite type representing a Substrait {@link + * io.substrait.expression.Expression.UserDefinedAny} type. + * + *

This type wraps opaque binary data (protobuf Any) and preserves all UDT metadata including + * type parameters during Calcite roundtrips. + * + *

Note: The actual value (protobuf Any) is not stored in the type itself - it's stored in the + * literal. This type only carries the metadata (URN, name, type parameters). + * + *

{@link io.substrait.expression.Expression.UserDefinedAny UserDefinedAny} literals use this type + * when passing through Calcite: they are serialized to binary and wrapped in a REINTERPRET + * call. UserDefinedStruct literals use {@link SubstraitUserDefinedStructType} instead. + * + * @see SubstraitUserDefinedStructType + * @see io.substrait.expression.Expression.UserDefinedAny + * @see io.substrait.expression.Expression.UserDefinedStruct + */ + public static class SubstraitUserDefinedAnyType extends SubstraitUserDefinedType { + + public SubstraitUserDefinedAnyType( + String urn, + String name, + List typeParameters, + boolean nullable) { + super(urn, name, typeParameters, nullable); + } + + /** Creates a SubstraitUserDefinedAnyType from a Substrait Type.UserDefined. */ + public static SubstraitUserDefinedAnyType from(io.substrait.type.Type.UserDefined type) { + return new SubstraitUserDefinedAnyType( + type.urn(), type.name(), type.typeParameters(), type.nullable()); + } + } + + /** + * Custom Calcite type representing a Substrait {@link + * io.substrait.expression.Expression.UserDefinedStruct} type. + * + *

This type represents a structured UDT with explicitly defined fields. Unlike {@link + * SubstraitUserDefinedAnyType}, the fields are accessible and can be represented as a Calcite + * STRUCT/ROW type with additional UDT metadata (URN, name, type parameters). + * + *

Note: UserDefinedStruct literals are represented in Calcite as ROW literals carrying this + * type; the synthetic field names are discarded when converting back to Substrait. They are + * not serialized to binary (that path is only used by {@link SubstraitUserDefinedAnyType}). + * + * @see SubstraitUserDefinedAnyType + * @see io.substrait.expression.Expression.UserDefinedStruct + */ + public static class SubstraitUserDefinedStructType extends SubstraitUserDefinedType { + + private final List fieldTypes; + private final List fieldNames; + + public SubstraitUserDefinedStructType( + String urn, + String name, + List typeParameters, + boolean nullable, + List fieldTypes, + List fieldNames) { + super(urn, name, typeParameters, nullable); + if (fieldTypes.size() != fieldNames.size()) { + throw new IllegalArgumentException("Field types and names must have same length"); + } + this.fieldTypes = fieldTypes; + this.fieldNames = fieldNames; + } + + @Override + public List getFieldList() { + java.util.List fields = new java.util.ArrayList<>(); + for (int i = 0; i < fieldTypes.size(); i++) { + fields.add(new RelDataTypeFieldImpl(fieldNames.get(i), i, fieldTypes.get(i))); + } + return fields; + } + + @Override + public int getFieldCount() { + return fieldTypes.size(); + } + + @Override + public RelDataTypeField getField(String fieldName, boolean caseSensitive, boolean elideRecord) { + for (int i = 0; i < fieldNames.size(); i++) { + String name = fieldNames.get(i); + if (caseSensitive ?
name.equals(fieldName) : name.equalsIgnoreCase(fieldName)) { + return new RelDataTypeFieldImpl(name, i, fieldTypes.get(i)); + } + } + return null; + } + + public List getFieldTypes() { + return fieldTypes; + } + + @Override + public List getFieldNames() { + return fieldNames; + } + + @Override + public SqlTypeName getSqlTypeName() { + // Can be considered as ROW since it has structure + return SqlTypeName.ROW; + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + sb.append(getName()); + if (!getTypeParameters().isEmpty()) { + sb.append("<"); + sb.append( + String.join(", ", java.util.Collections.nCopies(getTypeParameters().size(), "_"))); + sb.append(">"); + } + if (withDetail && fieldNames != null) { + sb.append("("); + sb.append( + java.util.stream.IntStream.range(0, fieldNames.size()) + .mapToObj(i -> fieldNames.get(i) + ": " + fieldTypes.get(i)) + .collect(java.util.stream.Collectors.joining(", "))); + sb.append(")"); + } + } + + /** + * Creates a SubstraitUserDefinedStructType from a Substrait Type.UserDefined and field + * information. 
+ */ + public static SubstraitUserDefinedStructType from( + io.substrait.type.Type.UserDefined type, + List fieldTypes, + List fieldNames) { + return new SubstraitUserDefinedStructType( + type.urn(), type.name(), type.typeParameters(), type.nullable(), fieldTypes, fieldNames); + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index c80f6a4ba..bde3f2a39 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -15,7 +15,6 @@ import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; -import io.substrait.proto.Expression.Literal.Builder; import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; @@ -287,6 +286,31 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r } } + // Helper methods for roundtrip assertions + + /** Assert that a relation roundtrips correctly through Calcite conversion. */ + private void assertCalciteRoundtrip(Rel originalRel) { + RelNode calciteRel = substraitToCalcite.convert(originalRel); + Rel calciteRoundtrippedRel = calciteToSubstrait.apply(calciteRel); + assertEquals(originalRel, calciteRoundtrippedRel); + } + + /** Assert that a relation roundtrips correctly through Proto serialization. 
*/ + private void assertProtoRoundtrip(Rel originalRel) { + ExtensionCollector extensionCollector = new ExtensionCollector(); + io.substrait.proto.Rel protoRel = + new RelProtoConverter(extensionCollector).toProto(originalRel); + Rel protoRoundtrippedRel = + new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); + assertEquals(originalRel, protoRoundtrippedRel); + } + + /** Assert that a relation roundtrips correctly through both Calcite and Proto conversions. */ + private void assertRoundtrip(Rel originalRel) { + assertCalciteRoundtrip(originalRel); + assertProtoRoundtrip(originalRel); + } + @Test void customScalarFunctionRoundtrip() { // CREATE TABLE example(a TEXT) @@ -585,24 +609,207 @@ void customTypesInFunctionsRoundtrip() { @Test void customTypesLiteralInFunctionsRoundtrip() { - Builder bldr = Expression.Literal.newBuilder(); + Expression.Literal.Builder bldr = Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue); - Rel rel1 = + Rel originalRel = b.project( input -> List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - RelNode calciteRel = substraitToCalcite.convert(rel1); - Rel rel2 = calciteToSubstrait.apply(calciteRel); - assertEquals(rel1, rel2); + assertRoundtrip(originalRel); + } - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(rel1, rel3); + @Test + void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { + // Test that UserDefinedAny literals with different 
payload types have different type names + // a_type wraps int, b_type wraps string - proto only + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI32(100).build()); + UserDefinedLiteral aTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); + UserDefinedLiteral bTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); + + Rel originalRel = + b.project( + input -> List.of(aTypeLit, bTypeLit), + b.remap(2), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); + } + + @Test + void userDefinedStructWithPrimitivesProtoRoundtrip() { + // Test UserDefinedStruct with various primitive field types - proto roundtrip only + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.string(false, "hello")) + .addFields(ExpressionCreator.bool(false, true)) + .addFields(ExpressionCreator.fp64(false, 2.718)) + .build(); + + Rel originalRel = + b.project( + input -> + List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); + } + + @Test + void userDefinedStructWithNestedStructProtoRoundtrip() { + // Test UserDefinedStruct with nested struct fields - proto roundtrip only + io.substrait.expression.Expression.StructLiteral innerStruct = + ExpressionCreator.struct( + false, ExpressionCreator.i32(false, 10), ExpressionCreator.string(false, "nested")); + + io.substrait.expression.Expression.UserDefinedStruct 
val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 100)) + .addFields(innerStruct) + .addFields(ExpressionCreator.bool(false, false)) + .build(); + + Rel originalRel = + b.project( + input -> + List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); + } + + @Test + void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { + // Test multiple UserDefinedStruct types with different struct schemas + // a_type: {content: string} + // b_type: {content_int: i32, content_fp: fp64} + io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.string(false, "hello")) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("b_type") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.fp64(false, 3.14159)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(aTypeStruct, bTypeStruct), + b.remap(2), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); + } + + @Test + void intermixedUserDefinedAnyAndStructProtoRoundtrip() { + // Test intermixing UserDefinedAny and UserDefinedStruct in the same query + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI64(999L).build()); + UserDefinedLiteral anyLit1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + io.substrait.expression.Expression.UserDefinedStruct structLit1 = + 
io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 123)) + .addFields(ExpressionCreator.bool(false, false)) + .build(); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Any anyValue2 = Any.pack(bldr2.setString("mixed").build()); + UserDefinedLiteral anyLit2 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue2); + + io.substrait.expression.Expression.UserDefinedStruct structLit2 = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.fp64(false, 1.414)) + .build(); + + Rel originalRel = + b.project( + input -> + List.of( + b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), anyLit1), + b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), structLit1), + b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), anyLit2), + b.scalarFn( + URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), structLit2)), + b.remap(4), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); + } + + @Test + void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { + // Test multiple different UDT type names (a_type, b_type) with both Any and Struct + Expression.Literal.Builder aTypeBldr = Expression.Literal.newBuilder(); + Any aTypeAny = Any.pack(aTypeBldr.setI32(42).build()); + UserDefinedLiteral aTypeAny1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", aTypeAny); + + io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 100)) + .build(); + + Expression.Literal.Builder bTypeBldr = Expression.Literal.newBuilder(); + Any bTypeAny = 
Any.pack(bTypeBldr.setString("b_val").build()); + UserDefinedLiteral bTypeAny1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", bTypeAny); + + io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("b_type") + .addFields(ExpressionCreator.string(false, "struct_b")) + .addFields(ExpressionCreator.bool(false, true)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(aTypeAny1, aTypeStruct, bTypeAny1, bTypeStruct), + b.remap(4), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertProtoRoundtrip(originalRel); } } diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 5377f4257..133052fa1 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -76,8 +76,12 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): String = { + expr.toString + } + override def visit( - expr: Expression.UserDefinedLiteral, + expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index 5f7137b14..07594d3bf 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,9 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) - override def visit( - userDefinedLiteral: 
Expression.UserDefinedLiteral, - context: EmptyVisitationContext): T = { - visitFallback(userDefinedLiteral, context) - } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): T = + visitFallback(expr, context) + + override def visit(expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): T = + visitFallback(expr, context) } From dcb152972b7ffe3efc1c854d64fe2b47d1f5bd08 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 17:38:44 -0500 Subject: [PATCH 02/11] some more refactor --- core/src/test/java/io/substrait/TestBase.java | 64 +++ .../type/proto/LiteralRoundtripTest.java | 47 +-- .../substrait/isthmus/CustomFunctionTest.java | 220 +--------- .../io/substrait/isthmus/PlanTestBase.java | 78 +++- .../isthmus/UserDefinedTypeLiteralTest.java | 395 ++++++++++++++++++ .../isthmus/utils/UserTypeFactory.java | 13 +- 6 files changed, 549 insertions(+), 268 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..93ed8c9a3 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -1,8 +1,12 @@ package io.substrait; +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; @@ -10,6 +14,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.TypeCreator; +import 
io.substrait.util.EmptyVisitationContext; public abstract class TestBase { @@ -30,4 +35,63 @@ protected void verifyRoundTrip(Rel rel) { Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + /** + * Assert that a literal/expression roundtrips correctly through Proto serialization. Uses default + * (null) extension collection. + */ + protected void assertLiteralRoundtrip(Expression.Literal literal) { + assertLiteralRoundtrip(literal, null); + } + + /** + * Assert that a literal/expression roundtrips correctly through Proto serialization. + * + * @param literal the literal to roundtrip + * @param extensions custom extension collection, or null to use no extensions + */ + protected void assertLiteralRoundtrip( + Expression.Literal literal, + SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection extensions) { + ExpressionProtoConverter toProto = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter fromProto = + new ProtoExpressionConverter(functionCollector, extensions, EMPTY_TYPE, protoRelConverter); + + io.substrait.proto.Expression protoExpr = + literal.accept(toProto, EmptyVisitationContext.INSTANCE); + io.substrait.proto.Expression.Literal protoLiteral = protoExpr.getLiteral(); + Expression.Literal roundtripped = fromProto.from(protoLiteral); + + assertEquals(literal, roundtripped); + } + + /** + * Assert that an expression roundtrips correctly through Proto serialization. Uses default (null) + * extension collection. + */ + protected void assertExpressionRoundtrip(Expression expression) { + assertExpressionRoundtrip(expression, null); + } + + /** + * Assert that an expression roundtrips correctly through Proto serialization. 
+ * + * @param expression the expression to roundtrip + * @param extensions custom extension collection, or null to use no extensions + */ + protected void assertExpressionRoundtrip( + Expression expression, + SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection extensions) { + ExpressionProtoConverter toProto = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + ProtoExpressionConverter fromProto = + new ProtoExpressionConverter(functionCollector, extensions, EMPTY_TYPE, protoRelConverter); + + io.substrait.proto.Expression protoExpression = + expression.accept(toProto, EmptyVisitationContext.INSTANCE); + Expression roundtripped = fromProto.from(protoExpression); + + assertEquals(expression, roundtripped); + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 1ce55163a..70f0462fb 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -1,16 +1,10 @@ package io.substrait.type.proto; -import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; -import static org.junit.jupiter.api.Assertions.assertEquals; - import com.google.protobuf.Any; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.proto.ExpressionProtoConverter; -import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.SimpleExtension; -import io.substrait.util.EmptyVisitationContext; import java.math.BigDecimal; import org.junit.jupiter.api.Test; @@ -24,10 +18,7 @@ public class LiteralRoundtripTest extends TestBase { void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new 
ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = - new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); + assertLiteralRoundtrip(val); } @Test @@ -42,12 +33,7 @@ void userDefinedLiteralWithAnyRepresentation() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, anyValue); - ExpressionProtoConverter exprProtoConv = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter protoExprConv = - new ProtoExpressionConverter( - functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); - assertEquals(val, protoExprConv.from(exprProtoConv.toProto(val))); + assertLiteralRoundtrip(val, testExtensions); } @Test @@ -61,12 +47,7 @@ void userDefinedLiteralWithStructRepresentation() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, fields); - ExpressionProtoConverter exprProtoConv = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter protoExprConv = - new ProtoExpressionConverter( - functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); - assertEquals(val, protoExprConv.from(exprProtoConv.toProto(val))); + assertLiteralRoundtrip(val, testExtensions); } @Test @@ -93,16 +74,7 @@ void userDefinedLiteralWithAnyRepresentationAndTypeParameters() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, typeParams, anyValue); - ExpressionProtoConverter exprProtoConv = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter protoExprConv = - new ProtoExpressionConverter( - functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); - - Expression.UserDefinedLiteral roundtripped = - (Expression.UserDefinedLiteral) protoExprConv.from(exprProtoConv.toProto(val)); - - 
assertEquals(val, roundtripped); + assertLiteralRoundtrip(val, testExtensions); } @Test @@ -128,15 +100,6 @@ void userDefinedLiteralWithStructRepresentationAndTypeParameters() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, typeParams, fields); - ExpressionProtoConverter exprProtoConv = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter protoExprConv = - new ProtoExpressionConverter( - functionCollector, testExtensions, EMPTY_TYPE, protoRelConverter); - - Expression.UserDefinedLiteral roundtripped = - (Expression.UserDefinedLiteral) protoExprConv.from(exprProtoConv.toProto(val)); - - assertEquals(val, roundtripped); + assertLiteralRoundtrip(val, testExtensions); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index bde3f2a39..1e80cf264 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -7,7 +7,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -15,9 +14,7 @@ import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; -import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; -import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -286,31 +283,6 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r 
} } - // Helper methods for roundtrip assertions - - /** Assert that a relation roundtrips correctly through Calcite conversion. */ - private void assertCalciteRoundtrip(Rel originalRel) { - RelNode calciteRel = substraitToCalcite.convert(originalRel); - Rel calciteRoundtrippedRel = calciteToSubstrait.apply(calciteRel); - assertEquals(originalRel, calciteRoundtrippedRel); - } - - /** Assert that a relation roundtrips correctly through Proto serialization. */ - private void assertProtoRoundtrip(Rel originalRel) { - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = - new RelProtoConverter(extensionCollector).toProto(originalRel); - Rel protoRoundtrippedRel = - new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(originalRel, protoRoundtrippedRel); - } - - /** Assert that a relation roundtrips correctly through both Calcite and Proto conversions. */ - private void assertRoundtrip(Rel originalRel) { - assertCalciteRoundtrip(originalRel); - assertProtoRoundtrip(originalRel); - } - @Test void customScalarFunctionRoundtrip() { // CREATE TABLE example(a TEXT) @@ -621,195 +593,7 @@ void customTypesLiteralInFunctionsRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - assertRoundtrip(originalRel); - } - - @Test - void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { - // Test that UserDefinedAny literals with different payload types have different type names - // a_type wraps int, b_type wraps string - proto only - Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = Any.pack(bldr1.setI32(100).build()); - UserDefinedLiteral aTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); - - Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); - Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); - UserDefinedLiteral bTypeLit = - 
ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); - - Rel originalRel = - b.project( - input -> List.of(aTypeLit, bTypeLit), - b.remap(2), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertProtoRoundtrip(originalRel); - } - - @Test - void userDefinedStructWithPrimitivesProtoRoundtrip() { - // Test UserDefinedStruct with various primitive field types - proto roundtrip only - io.substrait.expression.Expression.UserDefinedStruct val = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 42)) - .addFields(ExpressionCreator.string(false, "hello")) - .addFields(ExpressionCreator.bool(false, true)) - .addFields(ExpressionCreator.fp64(false, 2.718)) - .build(); - - Rel originalRel = - b.project( - input -> - List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertProtoRoundtrip(originalRel); - } - - @Test - void userDefinedStructWithNestedStructProtoRoundtrip() { - // Test UserDefinedStruct with nested struct fields - proto roundtrip only - io.substrait.expression.Expression.StructLiteral innerStruct = - ExpressionCreator.struct( - false, ExpressionCreator.i32(false, 10), ExpressionCreator.string(false, "nested")); - - io.substrait.expression.Expression.UserDefinedStruct val = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 100)) - .addFields(innerStruct) - .addFields(ExpressionCreator.bool(false, false)) - .build(); - - Rel originalRel = - b.project( - input -> - List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - 
assertProtoRoundtrip(originalRel); - } - - @Test - void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { - // Test multiple UserDefinedStruct types with different struct schemas - // a_type: {content: string} - // b_type: {content_int: i32, content_fp: fp64} - io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.string(false, "hello")) - .build(); - - io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("b_type") - .addFields(ExpressionCreator.i32(false, 42)) - .addFields(ExpressionCreator.fp64(false, 3.14159)) - .build(); - - Rel originalRel = - b.project( - input -> List.of(aTypeStruct, bTypeStruct), - b.remap(2), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertProtoRoundtrip(originalRel); - } - - @Test - void intermixedUserDefinedAnyAndStructProtoRoundtrip() { - // Test intermixing UserDefinedAny and UserDefinedStruct in the same query - Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = Any.pack(bldr1.setI64(999L).build()); - UserDefinedLiteral anyLit1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); - - io.substrait.expression.Expression.UserDefinedStruct structLit1 = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 123)) - .addFields(ExpressionCreator.bool(false, false)) - .build(); - - Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); - Any anyValue2 = Any.pack(bldr2.setString("mixed").build()); - UserDefinedLiteral anyLit2 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue2); - - 
io.substrait.expression.Expression.UserDefinedStruct structLit2 = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.fp64(false, 1.414)) - .build(); - - Rel originalRel = - b.project( - input -> - List.of( - b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), anyLit1), - b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), structLit1), - b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), anyLit2), - b.scalarFn( - URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), structLit2)), - b.remap(4), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertProtoRoundtrip(originalRel); - } - - @Test - void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { - // Test multiple different UDT type names (a_type, b_type) with both Any and Struct - Expression.Literal.Builder aTypeBldr = Expression.Literal.newBuilder(); - Any aTypeAny = Any.pack(aTypeBldr.setI32(42).build()); - UserDefinedLiteral aTypeAny1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", aTypeAny); - - io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 100)) - .build(); - - Expression.Literal.Builder bTypeBldr = Expression.Literal.newBuilder(); - Any bTypeAny = Any.pack(bTypeBldr.setString("b_val").build()); - UserDefinedLiteral bTypeAny1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", bTypeAny); - - io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = - io.substrait.expression.Expression.UserDefinedStruct.builder() - .nullable(false) - .urn(URN) - .name("b_type") - .addFields(ExpressionCreator.string(false, "struct_b")) - .addFields(ExpressionCreator.bool(false, true)) - .build(); - - Rel originalRel = 
- b.project( - input -> List.of(aTypeAny1, aTypeStruct, bTypeAny1, bTypeStruct), - b.remap(4), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertProtoRoundtrip(originalRel); + assertCalciteRoundtrip( + originalRel, substraitToCalcite, calciteToSubstrait, extensionCollection); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index cce58e207..e1e4db3ba 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -275,12 +275,50 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( } /** - * Verifies that the given POJO can be converted: + * Verifies that a relation roundtrips correctly through Calcite conversion. This is isthmus' core + * responsibility: Substrait ↔ Calcite. * - *

 <ul>
- *   <li>From POJO to Proto and back
- *   <li>From POJO to Calcite and back
- * </ul>
+ * @param rel the relation to roundtrip + */ + protected void assertCalciteRoundtrip(Rel rel) { + assertCalciteRoundtrip(rel, null, null, null); + } + + /** + * Verifies that a relation roundtrips correctly through Calcite conversion. This is isthmus' core + * responsibility: Substrait ↔ Calcite. + * + * @param rel the relation to roundtrip + * @param substraitToCalcite custom SubstraitToCalcite converter, or null to use default + * @param substraitRelVisitor custom SubstraitRelVisitor converter, or null to use default + * @param customExtensions custom extension collection, or null to use default + */ + protected void assertCalciteRoundtrip( + Rel rel, + @org.jspecify.annotations.Nullable SubstraitToCalcite substraitToCalcite, + @org.jspecify.annotations.Nullable SubstraitRelVisitor substraitRelVisitor, + SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection customExtensions) { + SimpleExtension.ExtensionCollection exts = + customExtensions != null ? customExtensions : extensions; + + // Substrait -> Calcite + SubstraitToCalcite s2c = + substraitToCalcite != null ? substraitToCalcite : new SubstraitToCalcite(exts, typeFactory); + RelNode calcite = s2c.convert(rel); + + // Calcite -> Substrait + io.substrait.relation.Rel roundtripped = + substraitRelVisitor != null + ? substraitRelVisitor.apply(calcite) + : SubstraitRelVisitor.convert(calcite, exts); + + assertEquals(rel, roundtripped); + } + + /** + * Verifies that a relation can be converted through both proto and Calcite roundtrips. + * + * @param pojo1 the relation to roundtrip */ protected void assertFullRoundTrip(Rel pojo1) { // TODO: reuse the Plan.Root based assertFullRoundTrip by generating names @@ -315,6 +353,25 @@ protected void assertFullRoundTrip(Rel pojo1) { * */ protected void assertFullRoundTrip(Plan.Root pojo1) { + assertFullRoundTrip(pojo1, null, null); + } + + /** + * Verifies that the given POJO can be converted: + * + *
 <ul>
+ *   <li>From POJO to Proto and back
+ *   <li>From POJO to Calcite and back
+ * </ul>
+ * + * @param pojo1 the plan root to roundtrip + * @param substraitToCalcite custom SubstraitToCalcite converter, or null to use default + * @param substraitRelVisitor custom SubstraitRelVisitor converter, or null to use default + */ + protected void assertFullRoundTrip( + Plan.Root pojo1, + @org.jspecify.annotations.Nullable SubstraitToCalcite substraitToCalcite, + @org.jspecify.annotations.Nullable SubstraitRelVisitor substraitRelVisitor) { ExtensionCollector extensionCollector = new ExtensionCollector(); // Substrait POJO 1 -> Substrait Proto @@ -328,10 +385,17 @@ protected void assertFullRoundTrip(Plan.Root pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + SubstraitToCalcite s2c = + substraitToCalcite != null + ? substraitToCalcite + : new SubstraitToCalcite(extensions, typeFactory); + RelRoot calcite = s2c.convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, extensions); + io.substrait.plan.Plan.Root pojo3 = + substraitRelVisitor != null + ? 
SubstraitRelVisitor.convert(calcite, substraitRelVisitor) + : SubstraitRelVisitor.convert(calcite, extensions); // Verify that POJOs are the same assertEquals(pojo1, pojo3); diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java new file mode 100644 index 000000000..59c513178 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -0,0 +1,395 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.protobuf.Any; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.isthmus.utils.UserTypeFactory; +import io.substrait.proto.Expression; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.tools.RelBuilder; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; + +/** + * Tests for User-Defined Type literals, including both UserDefinedAny (protobuf Any-based) and + * UserDefinedStruct (struct-based) encoding strategies. + * + *

These tests verify proto serialization/deserialization of UDT literals (core's + * responsibility), using custom extensions defined in isthmus test resources. + */ +public class UserDefinedTypeLiteralTest extends PlanTestBase { + + // Define custom types in a "functions_custom.yaml" extension + static final String URN = "extension:substrait:functions_custom"; + static final String FUNCTIONS_CUSTOM; + + static { + try { + FUNCTIONS_CUSTOM = asString("extensions/functions_custom.yaml"); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + // Load custom extension into an ExtensionCollection + static final SimpleExtension.ExtensionCollection testExtensions = + SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); + + final SubstraitBuilder b = new SubstraitBuilder(testExtensions); + + // Create user-defined types + static final String aTypeName = "a_type"; + static final String bTypeName = "b_type"; + static final UserTypeFactory aTypeFactory = new UserTypeFactory(URN, aTypeName); + static final UserTypeFactory bTypeFactory = new UserTypeFactory(URN, bTypeName); + + // Mapper for user-defined types + static final UserTypeMapper userTypeMapper = + new UserTypeMapper() { + @Nullable + @Override + public Type toSubstrait(RelDataType relDataType) { + if (aTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()).userDefined(URN, aTypeName); + } + if (bTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()).userDefined(URN, bTypeName); + } + return null; + } + + @Nullable + @Override + public RelDataType toCalcite(Type.UserDefined type) { + if (type.urn().equals(URN)) { + if (type.name().equals(aTypeName)) { + return aTypeFactory.createCalcite(type.nullable()); + } + if (type.name().equals(bTypeName)) { + return bTypeFactory.createCalcite(type.nullable()); + } + } + return null; + } + }; + + TypeConverter typeConverter = new TypeConverter(userTypeMapper); + + // Create 
Function Converters that can handle the custom types + ScalarFunctionConverter scalarFunctionConverter = + new ScalarFunctionConverter( + testExtensions.scalarFunctions(), List.of(), typeFactory, typeConverter); + AggregateFunctionConverter aggregateFunctionConverter = + new AggregateFunctionConverter( + testExtensions.aggregateFunctions(), List.of(), typeFactory, typeConverter); + WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter(testExtensions.windowFunctions(), typeFactory); + + final SubstraitToCalcite substraitToCalcite = + new CustomSubstraitToCalcite(testExtensions, typeFactory, typeConverter); + + // Create a SubstraitRelVisitor that uses the custom Function Converters + final SubstraitRelVisitor calciteToSubstrait = + new SubstraitRelVisitor( + typeFactory, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter, + ImmutableFeatureBoard.builder().build()); + + // Create a SubstraitToCalcite converter that has access to the custom Function Converters + class CustomSubstraitToCalcite extends SubstraitToCalcite { + + public CustomSubstraitToCalcite( + SimpleExtension.ExtensionCollection extensions, + RelDataTypeFactory typeFactory, + TypeConverter typeConverter) { + super(extensions, typeFactory, typeConverter); + } + + @Override + protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitRelNodeConverter( + typeFactory, + relBuilder, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter); + } + } + + /** + * Verifies proto roundtrip for a relation. This test class needs this method locally since it's + * testing proto serialization (core's responsibility) but must reside in isthmus to access custom + * test extensions. 
+ */ + private void verifyProtoRoundTrip(Rel rel) { + ExtensionCollector functionCollector = new ExtensionCollector(); + RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); + ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, testExtensions); + + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + Rel relReturned = protoRelConverter.from(protoRel); + assertEquals(rel, relReturned); + } + + @Test + void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { + // Test that UserDefinedAny literals with different payload types have different type names + // a_type wraps int, b_type wraps string - proto only + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI32(100).build()); + UserDefinedLiteral aTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); + UserDefinedLiteral bTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); + + Rel originalRel = + b.project( + input -> List.of(aTypeLit, bTypeLit), + b.remap(1, 2), // Select both expressions + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void singleUserDefinedAnyCalciteRoundtrip() { + // Test that a single UserDefinedAny literal can roundtrip through Calcite + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI32(100).build()); + UserDefinedLiteral aTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + Rel originalRel = + b.project( + input -> List.of(aTypeLit), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertCalciteRoundtrip(originalRel, 
substraitToCalcite, calciteToSubstrait, testExtensions); + } + + @Test + void singleUserDefinedStructCalciteRoundtrip() { + // Test that a single UserDefinedStruct literal can roundtrip through Calcite + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.string(false, "hello")) + .build(); + + Rel originalRel = + b.project( + input -> List.of(val), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertCalciteRoundtrip(originalRel, substraitToCalcite, calciteToSubstrait, testExtensions); + } + + @Test + void multipleDifferentUserDefinedAnyTypesCalciteRoundtrip() { + // Test that multiple UserDefinedAny literals with different types can roundtrip through Calcite + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI32(100).build()); + UserDefinedLiteral aTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); + UserDefinedLiteral bTypeLit = + ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); + + Rel originalRel = + b.project( + input -> List.of(aTypeLit, bTypeLit), + b.remap(1, 2), // Select both expressions + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + assertCalciteRoundtrip(originalRel, substraitToCalcite, calciteToSubstrait, testExtensions); + } + + @Test + void userDefinedStructWithPrimitivesProtoRoundtrip() { + // Test UserDefinedStruct with various primitive field types - proto roundtrip only + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + 
.nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.string(false, "hello")) + .addFields(ExpressionCreator.bool(false, true)) + .addFields(ExpressionCreator.fp64(false, 2.718)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(val), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void userDefinedStructWithNestedStructProtoRoundtrip() { + // Test UserDefinedStruct with nested struct fields - proto roundtrip only + io.substrait.expression.Expression.StructLiteral innerStruct = + ExpressionCreator.struct( + false, ExpressionCreator.i32(false, 10), ExpressionCreator.string(false, "nested")); + + io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 100)) + .addFields(innerStruct) + .addFields(ExpressionCreator.bool(false, false)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(val), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { + // Test multiple UserDefinedStruct types with different struct schemas + // a_type: {content: string} + // b_type: {content_int: i32, content_fp: fp64} + io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.string(false, "hello")) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("b_type") + 
.addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.fp64(false, 3.14159)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(aTypeStruct, bTypeStruct), + b.remap(2), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void intermixedUserDefinedAnyAndStructProtoRoundtrip() { + // Test intermixing UserDefinedAny and UserDefinedStruct in the same query + Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); + Any anyValue1 = Any.pack(bldr1.setI64(999L).build()); + UserDefinedLiteral anyLit1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + + io.substrait.expression.Expression.UserDefinedStruct structLit1 = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 123)) + .addFields(ExpressionCreator.bool(false, false)) + .build(); + + Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); + Any anyValue2 = Any.pack(bldr2.setString("mixed").build()); + UserDefinedLiteral anyLit2 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue2); + + io.substrait.expression.Expression.UserDefinedStruct structLit2 = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.fp64(false, 1.414)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(anyLit1, structLit1, anyLit2, structLit2), + b.remap(4), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } + + @Test + void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { + // Test multiple different UDT type names (a_type, b_type) with both Any and Struct + Expression.Literal.Builder aTypeBldr = Expression.Literal.newBuilder(); + Any aTypeAny = 
Any.pack(aTypeBldr.setI32(42).build()); + UserDefinedLiteral aTypeAny1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", aTypeAny); + + io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("a_type") + .addFields(ExpressionCreator.i32(false, 100)) + .build(); + + Expression.Literal.Builder bTypeBldr = Expression.Literal.newBuilder(); + Any bTypeAny = Any.pack(bTypeBldr.setString("b_val").build()); + UserDefinedLiteral bTypeAny1 = + ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", bTypeAny); + + io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(URN) + .name("b_type") + .addFields(ExpressionCreator.string(false, "struct_b")) + .addFields(ExpressionCreator.bool(false, true)) + .build(); + + Rel originalRel = + b.project( + input -> List.of(aTypeAny1, aTypeStruct, bTypeAny1, bTypeStruct), + b.remap(4), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + + verifyProtoRoundTrip(originalRel); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 2c90f133d..f1004261a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -36,7 +36,18 @@ public Type createSubstrait(boolean nullable) { } public boolean isTypeFromFactory(RelDataType type) { - return type == N || type == R; + // Use value-based comparison instead of reference equality to handle + // cases where the same type is created by different factory instances + if (type == N || type == R) { + return true; + } + // Check if this is a type with the same name and SqlTypeName.OTHER + if (type != null + && 
type.getSqlTypeName() == SqlTypeName.OTHER + && type.toString().equals(this.name)) { + return true; + } + return false; } private static class InnerType extends RelDataTypeImpl { From e9ea6cd99c9211bc8b3182a97f865ad943c37cdc Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:05:55 -0500 Subject: [PATCH 03/11] refactor: small simplification --- core/src/test/java/io/substrait/TestBase.java | 52 +++---------------- .../type/proto/LiteralRoundtripTest.java | 10 ++-- 2 files changed, 11 insertions(+), 51 deletions(-) diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 93ed8c9a3..b88b177c0 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -14,7 +14,6 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.TypeCreator; -import io.substrait.util.EmptyVisitationContext; public abstract class TestBase { @@ -36,62 +35,23 @@ protected void verifyRoundTrip(Rel rel) { assertEquals(rel, relReturned); } - /** - * Assert that a literal/expression roundtrips correctly through Proto serialization. Uses default - * (null) extension collection. - */ - protected void assertLiteralRoundtrip(Expression.Literal literal) { - assertLiteralRoundtrip(literal, null); - } - - /** - * Assert that a literal/expression roundtrips correctly through Proto serialization. 
- * - * @param literal the literal to roundtrip - * @param extensions custom extension collection, or null to use no extensions - */ - protected void assertLiteralRoundtrip( - Expression.Literal literal, - SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection extensions) { - ExpressionProtoConverter toProto = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter fromProto = - new ProtoExpressionConverter(functionCollector, extensions, EMPTY_TYPE, protoRelConverter); - - io.substrait.proto.Expression protoExpr = - literal.accept(toProto, EmptyVisitationContext.INSTANCE); - io.substrait.proto.Expression.Literal protoLiteral = protoExpr.getLiteral(); - Expression.Literal roundtripped = fromProto.from(protoLiteral); - - assertEquals(literal, roundtripped); - } - - /** - * Assert that an expression roundtrips correctly through Proto serialization. Uses default (null) - * extension collection. - */ - protected void assertExpressionRoundtrip(Expression expression) { - assertExpressionRoundtrip(expression, null); - } - /** * Assert that an expression roundtrips correctly through Proto serialization. 
* * @param expression the expression to roundtrip * @param extensions custom extension collection, or null to use no extensions */ - protected void assertExpressionRoundtrip( + protected void verifyRoundTrip( Expression expression, SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection extensions) { - ExpressionProtoConverter toProto = + ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter fromProto = + ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter(functionCollector, extensions, EMPTY_TYPE, protoRelConverter); - io.substrait.proto.Expression protoExpression = - expression.accept(toProto, EmptyVisitationContext.INSTANCE); - Expression roundtripped = fromProto.from(protoExpression); + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); + Expression expressionReturned = protoExpressionConverter.from(protoExpression); - assertEquals(expression, roundtripped); + assertEquals(expression, expressionReturned); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 70f0462fb..17665d6af 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -18,7 +18,7 @@ public class LiteralRoundtripTest extends TestBase { void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - assertLiteralRoundtrip(val); + verifyRoundTrip(val, null); } @Test @@ -33,7 +33,7 @@ void userDefinedLiteralWithAnyRepresentation() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, anyValue); - assertLiteralRoundtrip(val, testExtensions); + verifyRoundTrip(val, testExtensions); } @Test @@ -47,7 +47,7 
@@ void userDefinedLiteralWithStructRepresentation() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, fields); - assertLiteralRoundtrip(val, testExtensions); + verifyRoundTrip(val, testExtensions); } @Test @@ -74,7 +74,7 @@ void userDefinedLiteralWithAnyRepresentationAndTypeParameters() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, typeParams, anyValue); - assertLiteralRoundtrip(val, testExtensions); + verifyRoundTrip(val, testExtensions); } @Test @@ -100,6 +100,6 @@ void userDefinedLiteralWithStructRepresentationAndTypeParameters() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, typeParams, fields); - assertLiteralRoundtrip(val, testExtensions); + verifyRoundTrip(val, testExtensions); } } From 5a8bc402c1e226f2f2d6ff5258c498c02e0f0789 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:17:41 -0500 Subject: [PATCH 04/11] test: simplify roundtrip test --- .../extension/DefaultExtensionCatalog.java | 3 + core/src/test/java/io/substrait/TestBase.java | 23 ++--- .../type/proto/LiteralRoundtripTest.java | 87 ++++--------------- 3 files changed, 27 insertions(+), 86 deletions(-) diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..31214878c 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog { "extension:io.substrait:functions_rounding_decimal"; public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; + public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types"; public static 
final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); + defaultFiles.add("/extension_types.yaml"); + return SimpleExtension.load(defaultFiles); } } diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index b88b177c0..b5f1dd4f1 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -29,29 +29,22 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, defaultExtensionCollection); + protected ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + + protected ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionCollector, defaultExtensionCollection, EMPTY_TYPE, protoRelConverter); + protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } - /** - * Assert that an expression roundtrips correctly through Proto serialization. 
- * - * @param expression the expression to roundtrip - * @param extensions custom extension collection, or null to use no extensions - */ - protected void verifyRoundTrip( - Expression expression, - SimpleExtension.@org.jspecify.annotations.Nullable ExtensionCollection extensions) { - ExpressionProtoConverter expressionProtoConverter = - new ExpressionProtoConverter(functionCollector, relProtoConverter); - ProtoExpressionConverter protoExpressionConverter = - new ProtoExpressionConverter(functionCollector, extensions, EMPTY_TYPE, protoRelConverter); - + protected void verifyRoundTrip(Expression expression) { io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); Expression expressionReturned = protoExpressionConverter.from(protoExpression); - assertEquals(expression, expressionReturned); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 17665d6af..be6e8dcb7 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -4,102 +4,47 @@ import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.SimpleExtension; +import io.substrait.extension.DefaultExtensionCatalog; import java.math.BigDecimal; import org.junit.jupiter.api.Test; public class LiteralRoundtripTest extends TestBase { - // Load custom extensions for UserDefined literal tests - private static final SimpleExtension.ExtensionCollection testExtensions = - SimpleExtension.load(java.util.Arrays.asList("/extensions/custom_extensions.yaml")); - @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - verifyRoundTrip(val, null); + verifyRoundTrip(val); } @Test void 
userDefinedLiteralWithAnyRepresentation() { + // Create a struct literal inline representing a point with latitude=42, longitude=100 + io.substrait.proto.Expression.Literal.Struct pointStruct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(42)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(100)) + .build(); io.substrait.proto.Expression.Literal innerLiteral = - io.substrait.proto.Expression.Literal.newBuilder().setI32(42).build(); + io.substrait.proto.Expression.Literal.newBuilder().setStruct(pointStruct).build(); Any anyValue = Any.pack(innerLiteral); - String urn = "extension:test:custom_extensions"; - String typeName = "customType1"; - Expression.UserDefinedLiteral val = - ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, anyValue); + ExpressionCreator.userDefinedLiteralAny( + false, DefaultExtensionCatalog.EXTENSION_TYPES, "point", anyValue); - verifyRoundTrip(val, testExtensions); + verifyRoundTrip(val); } @Test void userDefinedLiteralWithStructRepresentation() { - String urn = "extension:test:custom_extensions"; - String typeName = "customType2"; - java.util.List fields = java.util.Arrays.asList( - ExpressionCreator.i32(false, 42), ExpressionCreator.string(false, "test")); - Expression.UserDefinedLiteral val = - ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, fields); - - verifyRoundTrip(val, testExtensions); - } - - @Test - void userDefinedLiteralWithAnyRepresentationAndTypeParameters() { - io.substrait.proto.Expression.Literal innerLiteral = - io.substrait.proto.Expression.Literal.newBuilder().setI32(42).build(); - Any anyValue = Any.pack(innerLiteral); - - String urn = "extension:test:custom_extensions"; - String typeName = "customType1"; - - java.util.List typeParams = - java.util.Arrays.asList( - io.substrait.proto.Type.Parameter.newBuilder() - .setDataType( - io.substrait.proto.Type.newBuilder() - .setI32( - 
io.substrait.proto.Type.I32 - .newBuilder() - .setNullability( - io.substrait.proto.Type.Nullability.NULLABILITY_REQUIRED))) - .build()); - - Expression.UserDefinedLiteral val = - ExpressionCreator.userDefinedLiteralAny(false, urn, typeName, typeParams, anyValue); - - verifyRoundTrip(val, testExtensions); - } - - @Test - void userDefinedLiteralWithStructRepresentationAndTypeParameters() { - String urn = "extension:test:custom_extensions"; - String typeName = "customType2"; - - java.util.List fields = - java.util.Arrays.asList( - ExpressionCreator.i32(false, 42), ExpressionCreator.string(false, "test")); - - java.util.List typeParams = - java.util.Arrays.asList( - io.substrait.proto.Type.Parameter.newBuilder() - .setDataType( - io.substrait.proto.Type.newBuilder() - .setString( - io.substrait.proto.Type.String.newBuilder() - .setNullability( - io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE))) - .build()); - + ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); Expression.UserDefinedLiteral val = - ExpressionCreator.userDefinedLiteralStruct(false, urn, typeName, typeParams, fields); + ExpressionCreator.userDefinedLiteralStruct( + false, DefaultExtensionCatalog.EXTENSION_TYPES, "point", fields); - verifyRoundTrip(val, testExtensions); + verifyRoundTrip(val); } } From ad94adc55bc501a15dd50bf8388843424f302f8e Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:21:44 -0500 Subject: [PATCH 05/11] revert: drop getExtensionCollection() impl from ExtensionCollector --- .../extension/ExtensionCollector.java | 71 ------- .../ExtensionCollectorGetCollectionTest.java | 173 ------------------ 2 files changed, 244 deletions(-) delete mode 100644 core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 1ac3b2f1b..7ad07a6b1 100644 --- 
a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -63,77 +63,6 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { return counter; } - /** - * Returns an ExtensionCollection containing only the types and functions that have been tracked - * by this collector. This provides a minimal collection with exactly what was used during - * serialization. - * - *

<p>This collection contains: - * - *
<ul> - *
<li>Only the types that were referenced via {@link #getTypeReference} - *
<li>Only the functions that were referenced via {@link #getFunctionReference} - *
<li>URI/URN mappings for only the used extension URNs - *
</ul> - * - *
<p>
Types from the catalog are resolved, while custom UserDefined types (not in the catalog) are - * created via {@link SimpleExtension.Type#of(String, String)}. - * - * @return an ExtensionCollection with only the used types, functions, and URI/URN mappings - */ - public SimpleExtension.ExtensionCollection getExtensionCollection() { - java.util.List types = new ArrayList<>(); - java.util.List scalarFunctions = new ArrayList<>(); - java.util.List aggregateFunctions = new ArrayList<>(); - java.util.List windowFunctions = new ArrayList<>(); - - java.util.Set usedUrns = new java.util.HashSet<>(); - - for (Map.Entry entry : typeMap.forwardEntrySet()) { - SimpleExtension.TypeAnchor anchor = entry.getValue(); - usedUrns.add(anchor.urn()); - if (extensionCollection.hasType(anchor)) { - types.add(extensionCollection.getType(anchor)); - } else { - types.add(SimpleExtension.Type.of(anchor.urn(), anchor.key())); - } - } - - for (Map.Entry entry : funcMap.forwardEntrySet()) { - SimpleExtension.FunctionAnchor anchor = entry.getValue(); - usedUrns.add(anchor.urn()); - - if (extensionCollection.hasScalarFunction(anchor)) { - scalarFunctions.add(extensionCollection.getScalarFunction(anchor)); - } else if (extensionCollection.hasAggregateFunction(anchor)) { - aggregateFunctions.add(extensionCollection.getAggregateFunction(anchor)); - } else if (extensionCollection.hasWindowFunction(anchor)) { - windowFunctions.add(extensionCollection.getWindowFunction(anchor)); - } else { - throw new IllegalArgumentException( - String.format( - "Function %s::%s was tracked but not found in catalog as scalar, aggregate, or window function", - anchor.urn(), anchor.key())); - } - } - - BidiMap uriUrnMap = new BidiMap<>(); - for (String urn : usedUrns) { - String uri = extensionCollection.getUriFromUrn(urn); - if (uri != null) { - uriUrnMap.put(uri, urn); - } - } - - return SimpleExtension.ExtensionCollection.builder() - .addAllTypes(types) - .addAllScalarFunctions(scalarFunctions) - 
.addAllAggregateFunctions(aggregateFunctions) - .addAllWindowFunctions(windowFunctions) - .uriUrnMap(uriUrnMap) - .build(); - } - public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java deleted file mode 100644 index eca391891..000000000 --- a/core/src/test/java/io/substrait/extension/ExtensionCollectorGetCollectionTest.java +++ /dev/null @@ -1,173 +0,0 @@ -package io.substrait.extension; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import org.junit.jupiter.api.Test; - -public class ExtensionCollectorGetCollectionTest { - - @Test - public void getExtensionCollection_containsOnlyTrackedTypes() { - BidiMap uriUrnMap = new BidiMap<>(); - uriUrnMap.put("https://example.com/catalog", "extension:example:catalog"); - - SimpleExtension.Type catalogType1 = - SimpleExtension.Type.of("extension:example:catalog", "type1"); - SimpleExtension.Type catalogType2 = - SimpleExtension.Type.of("extension:example:catalog", "type2"); - SimpleExtension.Type catalogType3 = - SimpleExtension.Type.of("extension:example:catalog", "type3"); - - SimpleExtension.ExtensionCollection catalog = - SimpleExtension.ExtensionCollection.builder() - .addTypes(catalogType1, catalogType2, catalogType3) - .uriUrnMap(uriUrnMap) - .build(); - - ExtensionCollector collector = new ExtensionCollector(catalog); - - collector.getTypeReference(catalogType1.getAnchor()); - collector.getTypeReference(catalogType2.getAnchor()); - - SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); - - assertEquals(2, result.types().size()); - assertEquals("type1", result.types().get(0).name()); - assertEquals("type2", 
result.types().get(1).name()); - } - - @Test - public void getExtensionCollection_containsOnlyTrackedFunctions() { - BidiMap uriUrnMap = new BidiMap<>(); - uriUrnMap.put("https://example.com/catalog", "extension:example:catalog"); - - SimpleExtension.ScalarFunctionVariant func1 = - ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .urn("extension:example:catalog") - .name("func1") - .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) - .build(); - - SimpleExtension.ScalarFunctionVariant func2 = - ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .urn("extension:example:catalog") - .name("func2") - .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) - .build(); - - SimpleExtension.ScalarFunctionVariant func3 = - ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .urn("extension:example:catalog") - .name("func3") - .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) - .build(); - - SimpleExtension.ExtensionCollection catalog = - SimpleExtension.ExtensionCollection.builder() - .addScalarFunctions(func1, func2, func3) - .uriUrnMap(uriUrnMap) - .build(); - - ExtensionCollector collector = new ExtensionCollector(catalog); - - collector.getFunctionReference(func1); - collector.getFunctionReference(func3); - - SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); - - assertEquals(2, result.scalarFunctions().size()); - assertEquals("func1", result.scalarFunctions().get(0).name()); - assertEquals("func3", result.scalarFunctions().get(1).name()); - } - - @Test - public void getExtensionCollection_includesCustomTypes() { - SimpleExtension.ExtensionCollection emptyCatalog = - SimpleExtension.ExtensionCollection.builder().build(); - - ExtensionCollector collector = new ExtensionCollector(emptyCatalog); - - SimpleExtension.TypeAnchor customType = - SimpleExtension.TypeAnchor.of("extension:test:custom", "MyCustomType"); - - collector.getTypeReference(customType); 
- - SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); - - assertEquals(1, result.types().size()); - assertEquals("MyCustomType", result.types().get(0).name()); - assertEquals("extension:test:custom", result.types().get(0).urn()); - } - - @Test - public void getExtensionCollection_includesOnlyUsedUriUrnMappings() { - BidiMap uriUrnMap = new BidiMap<>(); - uriUrnMap.put("https://example.com/urn1", "extension:example:urn1"); - uriUrnMap.put("https://example.com/urn2", "extension:example:urn2"); - uriUrnMap.put("https://example.com/urn3", "extension:example:urn3"); - - SimpleExtension.Type type1 = SimpleExtension.Type.of("extension:example:urn1", "type1"); - SimpleExtension.Type type2 = SimpleExtension.Type.of("extension:example:urn2", "type2"); - SimpleExtension.Type type3 = SimpleExtension.Type.of("extension:example:urn3", "type3"); - - SimpleExtension.ExtensionCollection catalog = - SimpleExtension.ExtensionCollection.builder() - .addTypes(type1, type2, type3) - .uriUrnMap(uriUrnMap) - .build(); - - ExtensionCollector collector = new ExtensionCollector(catalog); - - collector.getTypeReference(type1.getAnchor()); - collector.getTypeReference(type3.getAnchor()); - - SimpleExtension.ExtensionCollection result = collector.getExtensionCollection(); - - assertEquals(2, result.uriUrnMap().forwardEntrySet().size()); - assertNotNull(result.getUriFromUrn("extension:example:urn1")); - assertNotNull(result.getUriFromUrn("extension:example:urn3")); - assertEquals("https://example.com/urn1", result.getUriFromUrn("extension:example:urn1")); - assertEquals("https://example.com/urn3", result.getUriFromUrn("extension:example:urn3")); - } - - @Test - public void getExtensionCollection_emptyWhenNothingTracked() { - SimpleExtension.ExtensionCollection catalog = - SimpleExtension.ExtensionCollection.builder().build(); - - ExtensionCollector collector = new ExtensionCollector(catalog); - - SimpleExtension.ExtensionCollection result = 
collector.getExtensionCollection(); - - assertTrue(result.types().isEmpty()); - assertTrue(result.scalarFunctions().isEmpty()); - assertTrue(result.aggregateFunctions().isEmpty()); - assertTrue(result.windowFunctions().isEmpty()); - } - - @Test - public void getExtensionCollection_throwsWhenFunctionNotInCatalog() { - SimpleExtension.ExtensionCollection emptyCatalog = - SimpleExtension.ExtensionCollection.builder().build(); - - ExtensionCollector collector = new ExtensionCollector(emptyCatalog); - - SimpleExtension.ScalarFunctionVariant func = - ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .urn("extension:missing:catalog") - .name("missing_func") - .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) - .build(); - - collector.getFunctionReference(func); - - IllegalArgumentException exception = - org.junit.jupiter.api.Assertions.assertThrows( - IllegalArgumentException.class, () -> collector.getExtensionCollection()); - - assertTrue(exception.getMessage().contains("extension:missing:catalog::missing_func")); - assertTrue(exception.getMessage().contains("not found in catalog")); - } -} From 5eeff70590c0e0f972714399ace224ab3f51e11e Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:27:41 -0500 Subject: [PATCH 06/11] tweak: simplify expression creator --- .../expression/ExpressionCreator.java | 40 +------------------ .../type/proto/LiteralRoundtripTest.java | 12 +++++- .../substrait/isthmus/CustomFunctionTest.java | 3 +- .../isthmus/UserDefinedTypeLiteralTest.java | 27 ++++++++----- 4 files changed, 32 insertions(+), 50 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 74de45732..2f924bef8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -292,25 +292,7 @@ public static Expression.StructLiteral struct( 
* @param nullable whether the literal is nullable * @param urn the URN of the user-defined type * @param name the name of the user-defined type - * @param value the value, encoded as google.protobuf.Any - */ - public static Expression.UserDefinedAny userDefinedLiteralAny( - boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedAny.builder() - .nullable(nullable) - .urn(urn) - .name(name) - .value(value) - .build(); - } - - /** - * Create a UserDefinedAny with google.protobuf.Any representation and type parameters. - * - * @param nullable whether the literal is nullable - * @param urn the URN of the user-defined type - * @param name the name of the user-defined type - * @param typeParameters the type parameters for the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) * @param value the value, encoded as google.protobuf.Any */ public static Expression.UserDefinedAny userDefinedLiteralAny( @@ -334,25 +316,7 @@ public static Expression.UserDefinedAny userDefinedLiteralAny( * @param nullable whether the literal is nullable * @param urn the URN of the user-defined type * @param name the name of the user-defined type - * @param fields the fields, as a list of Literal values - */ - public static Expression.UserDefinedStruct userDefinedLiteralStruct( - boolean nullable, String urn, String name, java.util.List fields) { - return Expression.UserDefinedStruct.builder() - .nullable(nullable) - .urn(urn) - .name(name) - .addAllFields(fields) - .build(); - } - - /** - * Create a UserDefinedStruct with Struct representation and type parameters. 
- * - * @param nullable whether the literal is nullable - * @param urn the URN of the user-defined type - * @param name the name of the user-defined type - * @param typeParameters the type parameters for the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) * @param fields the fields, as a list of Literal values */ public static Expression.UserDefinedStruct userDefinedLiteralStruct( diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index be6e8dcb7..8e4b13cbe 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -31,7 +31,11 @@ void userDefinedLiteralWithAnyRepresentation() { Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny( - false, DefaultExtensionCatalog.EXTENSION_TYPES, "point", anyValue); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue); verifyRoundTrip(val); } @@ -43,7 +47,11 @@ void userDefinedLiteralWithStructRepresentation() { ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralStruct( - false, DefaultExtensionCatalog.EXTENSION_TYPES, "point", fields); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + fields); verifyRoundTrip(val); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 1e80cf264..61a0603d0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -584,7 +584,8 @@ void customTypesLiteralInFunctionsRoundtrip() { Expression.Literal.Builder bldr = 
Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); UserDefinedLiteral val = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue); Rel originalRel = b.project( diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java index 59c513178..4b3fbfe36 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -159,12 +159,14 @@ void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); Any anyValue1 = Any.pack(bldr1.setI32(100).build()); UserDefinedLiteral aTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); UserDefinedLiteral bTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "b_type", java.util.Collections.emptyList(), anyValue2); Rel originalRel = b.project( @@ -181,7 +183,8 @@ void singleUserDefinedAnyCalciteRoundtrip() { Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); Any anyValue1 = Any.pack(bldr1.setI32(100).build()); UserDefinedLiteral aTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); Rel originalRel = b.project( @@ -219,12 +222,14 @@ void 
multipleDifferentUserDefinedAnyTypesCalciteRoundtrip() { Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); Any anyValue1 = Any.pack(bldr1.setI32(100).build()); UserDefinedLiteral aTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); UserDefinedLiteral bTypeLit = - ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", anyValue2); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "b_type", java.util.Collections.emptyList(), anyValue2); Rel originalRel = b.project( @@ -321,7 +326,8 @@ void intermixedUserDefinedAnyAndStructProtoRoundtrip() { Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); Any anyValue1 = Any.pack(bldr1.setI64(999L).build()); UserDefinedLiteral anyLit1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue1); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); io.substrait.expression.Expression.UserDefinedStruct structLit1 = io.substrait.expression.Expression.UserDefinedStruct.builder() @@ -335,7 +341,8 @@ void intermixedUserDefinedAnyAndStructProtoRoundtrip() { Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); Any anyValue2 = Any.pack(bldr2.setString("mixed").build()); UserDefinedLiteral anyLit2 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", anyValue2); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue2); io.substrait.expression.Expression.UserDefinedStruct structLit2 = io.substrait.expression.Expression.UserDefinedStruct.builder() @@ -360,7 +367,8 @@ void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { Expression.Literal.Builder aTypeBldr 
= Expression.Literal.newBuilder(); Any aTypeAny = Any.pack(aTypeBldr.setI32(42).build()); UserDefinedLiteral aTypeAny1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "a_type", aTypeAny); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), aTypeAny); io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = io.substrait.expression.Expression.UserDefinedStruct.builder() @@ -373,7 +381,8 @@ void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { Expression.Literal.Builder bTypeBldr = Expression.Literal.newBuilder(); Any bTypeAny = Any.pack(bTypeBldr.setString("b_val").build()); UserDefinedLiteral bTypeAny1 = - ExpressionCreator.userDefinedLiteralAny(false, URN, "b_type", bTypeAny); + ExpressionCreator.userDefinedLiteralAny( + false, URN, "b_type", java.util.Collections.emptyList(), bTypeAny); io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = io.substrait.expression.Expression.UserDefinedStruct.builder() From fe01302cbbf0604e4946ad481b65b796129b8a72 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:31:34 -0500 Subject: [PATCH 07/11] revert: simple extension --- .../substrait/extension/SimpleExtension.java | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 02608cbf3..39d7c45e0 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -551,18 +551,6 @@ public abstract static class Type { public TypeAnchor getAnchor() { return anchorSupplier.get(); } - - /** - * Creates a minimal Type instance for custom UserDefined types that aren't loaded from YAML. - * This is useful for programmatically constructed types during protobuf deserialization. 
- * - * @param urn the extension URN (e.g., "extension:test:custom") - * @param name the type name (e.g., "MyCustomType") - * @return a Type instance with the specified urn and name - */ - public static Type of(String urn, String name) { - return ImmutableSimpleExtension.Type.builder().urn(urn).name(name).build(); - } } @JsonDeserialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) @@ -678,10 +666,6 @@ public Type getType(TypeAnchor anchor) { anchor.key(), anchor.urn())); } - public boolean hasType(TypeAnchor anchor) { - return typeLookup.get().containsKey(anchor); - } - public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { ScalarFunctionVariant variant = scalarFunctionsLookup.get().get(anchor); if (variant != null) { @@ -734,18 +718,6 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } - public boolean hasScalarFunction(FunctionAnchor anchor) { - return scalarFunctionsLookup.get().containsKey(anchor); - } - - public boolean hasAggregateFunction(FunctionAnchor anchor) { - return aggregateFunctionsLookup.get().containsKey(anchor); - } - - public boolean hasWindowFunction(FunctionAnchor anchor) { - return windowFunctionsLookup.get().containsKey(anchor); - } - /** * Gets the URI for a given URN. This is for internal framework use during URI/URN migration. 
* From ccbd5a3e02cede5634f42fd628f6aab4cdbeabba Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 22:46:40 -0500 Subject: [PATCH 08/11] tweak: simplify TypeConverter.java --- .../io/substrait/isthmus/TypeConverter.java | 3 +- .../type/SubstraitUserDefinedType.java | 35 +++++++++++-------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 5cdf8088f..c332dfd19 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -363,8 +363,7 @@ public RelDataType visit(Type.UserDefined expr) throws RuntimeException { if (type != null) { return type; } - return io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( - expr); + return io.substrait.isthmus.type.SubstraitUserDefinedType.from(expr); } private boolean n(NullableType type) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java index 48fda7c71..e9b9c1a8f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java +++ b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java @@ -9,13 +9,11 @@ import org.apache.calcite.sql.type.SqlTypeName; /** - * Base class for custom Calcite {@link RelDataType} implementations representing Substrait - * user-defined types. + * Custom Calcite {@link RelDataType} for Substrait user-defined types. * - *

These custom types preserve all UDT metadata (URN, name, type parameters) during Calcite - * roundtrips, unlike the previous approach which flattened everything to binary with REINTERPRET. - * - *

Two concrete implementations exist: + *

This type preserves all UDT metadata (URN, name, type parameters) during Calcite roundtrips. + * It is used when converting types without literal context. For literals, specialized subclasses + * provide representation-specific handling: * *

    *
  • {@link SubstraitUserDefinedAnyType} - For opaque binary UDT literals (wraps protobuf Any) @@ -27,14 +25,14 @@ * @see io.substrait.expression.Expression.UserDefinedAny * @see io.substrait.expression.Expression.UserDefinedStruct */ -public abstract class SubstraitUserDefinedType extends RelDataTypeImpl { +public class SubstraitUserDefinedType extends RelDataTypeImpl { private final String urn; private final String name; private final List typeParameters; private final boolean nullable; - protected SubstraitUserDefinedType( + public SubstraitUserDefinedType( String urn, String name, List typeParameters, @@ -79,6 +77,12 @@ public Type.UserDefined toSubstraitType() { .build(); } + /** Creates a SubstraitUserDefinedType from a Substrait Type.UserDefined. */ + public static SubstraitUserDefinedType from(io.substrait.type.Type.UserDefined type) { + return new SubstraitUserDefinedType( + type.urn(), type.name(), type.typeParameters(), type.nullable()); + } + @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { sb.append(name); @@ -99,13 +103,13 @@ protected void generateTypeString(StringBuilder sb, boolean withDetail) { *

    Note: The actual value (protobuf Any) is not stored in the type itself - it's stored in the * literal. This type only carries the metadata (URN, name, type parameters). * - *

    Both {@link io.substrait.expression.Expression.UserDefinedAny UserDefinedAny} and {@link - * io.substrait.expression.Expression.UserDefinedStruct UserDefinedStruct} literals use this type - * when passing through Calcite, as they both need to be serialized to binary with REINTERPRET. + *

    {@link io.substrait.expression.Expression.UserDefinedAny UserDefinedAny} literals use this + * type when passing through Calcite, as they need to be serialized to binary with REINTERPRET. + * {@link io.substrait.expression.Expression.UserDefinedStruct UserDefinedStruct} literals use + * {@link SubstraitUserDefinedStructType} instead to preserve field structure. * * @see SubstraitUserDefinedStructType * @see io.substrait.expression.Expression.UserDefinedAny - * @see io.substrait.expression.Expression.UserDefinedStruct */ public static class SubstraitUserDefinedAnyType extends SubstraitUserDefinedType { @@ -132,9 +136,10 @@ public static SubstraitUserDefinedAnyType from(io.substrait.type.Type.UserDefine * SubstraitUserDefinedAnyType}, the fields are accessible and can be represented as a Calcite * STRUCT/ROW type with additional UDT metadata (URN, name, type parameters). * - *

    Note: Currently, UserDefinedStruct literals are serialized to binary when passing through - * Calcite (using {@link SubstraitUserDefinedAnyType}), so this structured type is primarily for - * future use when Calcite can better handle structured user-defined types. + *

    {@link io.substrait.expression.Expression.UserDefinedStruct UserDefinedStruct} literals use + * this type when passing through Calcite, preserving field structure and enabling field access. + * The fields are converted to Calcite literals and wrapped in a ROW type with synthetic field + * names (f0, f1, f2, etc.). * * @see SubstraitUserDefinedAnyType * @see io.substrait.expression.Expression.UserDefinedStruct From 99fe14c877fb56d7a3fe21f002c8a44cdddc4b3e Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 19 Nov 2025 23:42:44 -0500 Subject: [PATCH 09/11] tweak: use default types and simplify code a bit --- .../type/SubstraitUserDefinedType.java | 8 +- .../isthmus/UserDefinedTypeLiteralTest.java | 443 ++++++++++++------ .../isthmus/utils/UserTypeFactory.java | 6 +- 3 files changed, 314 insertions(+), 143 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java index e9b9c1a8f..2ce33eb5d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java +++ b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java @@ -85,7 +85,9 @@ public static SubstraitUserDefinedType from(io.substrait.type.Type.UserDefined t @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(name); + // Include URN in type string to ensure types with same name but different URNs + // are not considered equal by Calcite's type system + sb.append(urn).append("::").append(name); if (!typeParameters.isEmpty()) { sb.append("<"); sb.append(String.join(", ", java.util.Collections.nCopies(typeParameters.size(), "_"))); @@ -206,7 +208,9 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(getName()); + // Include URN in type string to ensure types with same name but different URNs + // are not considered 
equal by Calcite's type system + sb.append(getUrn()).append("::").append(getName()); if (!getTypeParameters().isEmpty()) { sb.append("<"); sb.append( diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java index 4b3fbfe36..cd1608a94 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -6,6 +6,7 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; @@ -18,8 +19,6 @@ import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; -import java.io.IOException; -import java.io.UncheckedIOException; import java.util.List; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -31,34 +30,18 @@ * Tests for User-Defined Type literals, including both UserDefinedAny (protobuf Any-based) and * UserDefinedStruct (struct-based) encoding strategies. * - *

    These tests verify proto serialization/deserialization of UDT literals (core's - * responsibility), using custom extensions defined in isthmus test resources. + *

    These tests verify proto serialization/deserialization and Calcite roundtrips of UDT literals, + * using standard types from extension_types.yaml (point and line). */ public class UserDefinedTypeLiteralTest extends PlanTestBase { - // Define custom types in a "functions_custom.yaml" extension - static final String URN = "extension:substrait:functions_custom"; - static final String FUNCTIONS_CUSTOM; + final SubstraitBuilder b = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION); - static { - try { - FUNCTIONS_CUSTOM = asString("extensions/functions_custom.yaml"); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - // Load custom extension into an ExtensionCollection - static final SimpleExtension.ExtensionCollection testExtensions = - SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); - - final SubstraitBuilder b = new SubstraitBuilder(testExtensions); - - // Create user-defined types - static final String aTypeName = "a_type"; - static final String bTypeName = "b_type"; - static final UserTypeFactory aTypeFactory = new UserTypeFactory(URN, aTypeName); - static final UserTypeFactory bTypeFactory = new UserTypeFactory(URN, bTypeName); + // Create user-defined types using standard types from extension_types.yaml + static final UserTypeFactory pointTypeFactory = + new UserTypeFactory(DefaultExtensionCatalog.EXTENSION_TYPES, "point"); + static final UserTypeFactory lineTypeFactory = + new UserTypeFactory(DefaultExtensionCatalog.EXTENSION_TYPES, "line"); // Mapper for user-defined types static final UserTypeMapper userTypeMapper = @@ -66,11 +49,13 @@ public class UserDefinedTypeLiteralTest extends PlanTestBase { @Nullable @Override public Type toSubstrait(RelDataType relDataType) { - if (aTypeFactory.isTypeFromFactory(relDataType)) { - return TypeCreator.of(relDataType.isNullable()).userDefined(URN, aTypeName); + if (pointTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()) + 
.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point"); } - if (bTypeFactory.isTypeFromFactory(relDataType)) { - return TypeCreator.of(relDataType.isNullable()).userDefined(URN, bTypeName); + if (lineTypeFactory.isTypeFromFactory(relDataType)) { + return TypeCreator.of(relDataType.isNullable()) + .userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line"); } return null; } @@ -78,12 +63,12 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.urn().equals(URN)) { - if (type.name().equals(aTypeName)) { - return aTypeFactory.createCalcite(type.nullable()); + if (type.urn().equals(DefaultExtensionCatalog.EXTENSION_TYPES)) { + if (type.name().equals("point")) { + return pointTypeFactory.createCalcite(type.nullable()); } - if (type.name().equals(bTypeName)) { - return bTypeFactory.createCalcite(type.nullable()); + if (type.name().equals("line")) { + return lineTypeFactory.createCalcite(type.nullable()); } } return null; @@ -92,18 +77,26 @@ public RelDataType toCalcite(Type.UserDefined type) { TypeConverter typeConverter = new TypeConverter(userTypeMapper); - // Create Function Converters that can handle the custom types + // Create Function Converters that can handle the user-defined types ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter( - testExtensions.scalarFunctions(), List.of(), typeFactory, typeConverter); + DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions(), + List.of(), + typeFactory, + typeConverter); AggregateFunctionConverter aggregateFunctionConverter = new AggregateFunctionConverter( - testExtensions.aggregateFunctions(), List.of(), typeFactory, typeConverter); + DefaultExtensionCatalog.DEFAULT_COLLECTION.aggregateFunctions(), + List.of(), + typeFactory, + typeConverter); WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter(testExtensions.windowFunctions(), typeFactory); + new 
WindowFunctionConverter( + DefaultExtensionCatalog.DEFAULT_COLLECTION.windowFunctions(), typeFactory); final SubstraitToCalcite substraitToCalcite = - new CustomSubstraitToCalcite(testExtensions, typeFactory, typeConverter); + new CustomSubstraitToCalcite( + DefaultExtensionCatalog.DEFAULT_COLLECTION, typeFactory, typeConverter); // Create a SubstraitRelVisitor that uses the custom Function Converters final SubstraitRelVisitor calciteToSubstrait = @@ -139,13 +132,14 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r /** * Verifies proto roundtrip for a relation. This test class needs this method locally since it's - * testing proto serialization (core's responsibility) but must reside in isthmus to access custom - * test extensions. + * testing proto serialization (core's responsibility) but must reside in isthmus to access + * Calcite integration components. */ private void verifyProtoRoundTrip(Rel rel) { ExtensionCollector functionCollector = new ExtensionCollector(); RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); - ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, testExtensions); + ProtoRelConverter protoRelConverter = + new ProtoRelConverter(functionCollector, DefaultExtensionCatalog.DEFAULT_COLLECTION); io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); @@ -154,25 +148,46 @@ private void verifyProtoRoundTrip(Rel rel) { @Test void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { - // Test that UserDefinedAny literals with different payload types have different type names - // a_type wraps int, b_type wraps string - proto only + // Test that UserDefinedAny literals with different type names - proto only + // point wraps struct with two i32 fields, line wraps struct with two point fields Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = 
Any.pack(bldr1.setI32(100).build()); - UserDefinedLiteral aTypeLit = + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); - Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); - UserDefinedLiteral bTypeLit = + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(bldr1.build()) // reuse point struct as start + .addFields(bldr1.build()) // reuse point struct as end + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(lineStruct).build()); + UserDefinedLiteral lineLit = ExpressionCreator.userDefinedLiteralAny( - false, URN, "b_type", java.util.Collections.emptyList(), anyValue2); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + anyValue2); Rel originalRel = b.project( - input -> List.of(aTypeLit, bTypeLit), + input -> List.of(pointLit, lineLit), b.remap(1, 2), // Select both expressions - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } @@ -181,18 +196,34 @@ void multipleDifferentUserDefinedAnyTypesProtoRoundtrip() { void singleUserDefinedAnyCalciteRoundtrip() { // Test that a single UserDefinedAny literal can roundtrip through Calcite Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = 
Any.pack(bldr1.setI32(100).build()); - UserDefinedLiteral aTypeLit = + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); Rel originalRel = b.project( - input -> List.of(aTypeLit), + input -> List.of(pointLit), b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertCalciteRoundtrip(originalRel, substraitToCalcite, calciteToSubstrait, testExtensions); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); } @Test @@ -201,90 +232,140 @@ void singleUserDefinedStructCalciteRoundtrip() { io.substrait.expression.Expression.UserDefinedStruct val = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") .addFields(ExpressionCreator.i32(false, 42)) - .addFields(ExpressionCreator.string(false, "hello")) + .addFields(ExpressionCreator.i32(false, 100)) .build(); Rel originalRel = b.project( input -> List.of(val), b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertCalciteRoundtrip(originalRel, substraitToCalcite, calciteToSubstrait, testExtensions); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, 
"point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); } @Test void multipleDifferentUserDefinedAnyTypesCalciteRoundtrip() { // Test that multiple UserDefinedAny literals with different types can roundtrip through Calcite Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = Any.pack(bldr1.setI32(100).build()); - UserDefinedLiteral aTypeLit = + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct).build()); + UserDefinedLiteral pointLit = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); - Any anyValue2 = Any.pack(bldr2.setString("b_value").build()); - UserDefinedLiteral bTypeLit = + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(bldr1.build()) + .addFields(bldr1.build()) + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(lineStruct).build()); + UserDefinedLiteral lineLit = ExpressionCreator.userDefinedLiteralAny( - false, URN, "b_type", java.util.Collections.emptyList(), anyValue2); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + anyValue2); Rel originalRel = b.project( - input -> List.of(aTypeLit, bTypeLit), + input -> List.of(pointLit, lineLit), b.remap(1, 2), // Select both expressions - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - - assertCalciteRoundtrip(originalRel, substraitToCalcite, calciteToSubstrait, testExtensions); + b.namedScan( + 
List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); } @Test void userDefinedStructWithPrimitivesProtoRoundtrip() { - // Test UserDefinedStruct with various primitive field types - proto roundtrip only + // Test UserDefinedStruct with primitive field types - proto roundtrip only io.substrait.expression.Expression.UserDefinedStruct val = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") .addFields(ExpressionCreator.i32(false, 42)) - .addFields(ExpressionCreator.string(false, "hello")) - .addFields(ExpressionCreator.bool(false, true)) - .addFields(ExpressionCreator.fp64(false, 2.718)) + .addFields(ExpressionCreator.i32(false, 100)) .build(); Rel originalRel = b.project( input -> List.of(val), b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } @Test void userDefinedStructWithNestedStructProtoRoundtrip() { - // Test UserDefinedStruct with nested struct fields - proto roundtrip only - io.substrait.expression.Expression.StructLiteral innerStruct = - ExpressionCreator.struct( - false, ExpressionCreator.i32(false, 10), ExpressionCreator.string(false, "nested")); + // Test UserDefinedStruct with nested UDT fields - proto roundtrip only + // line contains nested point UDT fields + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 10)) + 
.addFields(ExpressionCreator.i32(false, 20)) + .build(); - io.substrait.expression.Expression.UserDefinedStruct val = + io.substrait.expression.Expression.UserDefinedStruct endPoint = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 100)) - .addFields(innerStruct) - .addFields(ExpressionCreator.bool(false, false)) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 30)) + .addFields(ExpressionCreator.i32(false, 40)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct line = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) .build(); Rel originalRel = b.project( - input -> List.of(val), + input -> List.of(line), b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); verifyProtoRoundTrip(originalRel); } @@ -292,30 +373,52 @@ void userDefinedStructWithNestedStructProtoRoundtrip() { @Test void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { // Test multiple UserDefinedStruct types with different struct schemas - // a_type: {content: string} - // b_type: {content_int: i32, content_fp: fp64} - io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + // point: {latitude: i32, longitude: i32} + // line: {start: point, end: point} + io.substrait.expression.Expression.UserDefinedStruct pointStruct = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.string(false, "hello")) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + 
.addFields(ExpressionCreator.i32(false, 42)) + .addFields(ExpressionCreator.i32(false, 100)) .build(); - io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct startPoint = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("b_type") - .addFields(ExpressionCreator.i32(false, 42)) - .addFields(ExpressionCreator.fp64(false, 3.14159)) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 10)) + .addFields(ExpressionCreator.i32(false, 20)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 30)) + .addFields(ExpressionCreator.i32(false, 40)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) .build(); Rel originalRel = b.project( - input -> List.of(aTypeStruct, bTypeStruct), + input -> List.of(pointStruct, lineStruct), b.remap(2), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } @@ -324,80 +427,142 @@ void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { void intermixedUserDefinedAnyAndStructProtoRoundtrip() { // Test intermixing UserDefinedAny and UserDefinedStruct in the same query Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); - Any anyValue1 = Any.pack(bldr1.setI64(999L).build()); + Expression.Literal.Struct pointStruct1 = 
+ Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(10)) + .addFields(Expression.Literal.newBuilder().setI32(20)) + .build(); + Any anyValue1 = Any.pack(bldr1.setStruct(pointStruct1).build()); UserDefinedLiteral anyLit1 = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), anyValue1); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue1); io.substrait.expression.Expression.UserDefinedStruct structLit1 = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") .addFields(ExpressionCreator.i32(false, 123)) - .addFields(ExpressionCreator.bool(false, false)) + .addFields(ExpressionCreator.i32(false, 456)) .build(); Expression.Literal.Builder bldr2 = Expression.Literal.newBuilder(); - Any anyValue2 = Any.pack(bldr2.setString("mixed").build()); + Expression.Literal.Struct pointStruct2 = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(30)) + .addFields(Expression.Literal.newBuilder().setI32(40)) + .build(); + Any anyValue2 = Any.pack(bldr2.setStruct(pointStruct2).build()); UserDefinedLiteral anyLit2 = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), anyValue2); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue2); io.substrait.expression.Expression.UserDefinedStruct structLit2 = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.fp64(false, 1.414)) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 789)) + .addFields(ExpressionCreator.i32(false, 101)) .build(); Rel originalRel = b.project( input -> List.of(anyLit1, 
structLit1, anyLit2, structLit2), b.remap(4), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } @Test void multipleDifferentUDTTypesWithAnyAndStructProtoRoundtrip() { - // Test multiple different UDT type names (a_type, b_type) with both Any and Struct - Expression.Literal.Builder aTypeBldr = Expression.Literal.newBuilder(); - Any aTypeAny = Any.pack(aTypeBldr.setI32(42).build()); - UserDefinedLiteral aTypeAny1 = + // Test multiple different UDT type names (point, line) with both Any and Struct + Expression.Literal.Builder pointBldr = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStruct = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(42)) + .addFields(Expression.Literal.newBuilder().setI32(100)) + .build(); + Any pointAny = Any.pack(pointBldr.setStruct(pointStruct).build()); + UserDefinedLiteral pointAny1 = ExpressionCreator.userDefinedLiteralAny( - false, URN, "a_type", java.util.Collections.emptyList(), aTypeAny); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + pointAny); - io.substrait.expression.Expression.UserDefinedStruct aTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct pointStructLit = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("a_type") - .addFields(ExpressionCreator.i32(false, 100)) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 10)) + .addFields(ExpressionCreator.i32(false, 20)) .build(); - Expression.Literal.Builder bTypeBldr = Expression.Literal.newBuilder(); - Any bTypeAny = Any.pack(bTypeBldr.setString("b_val").build()); - UserDefinedLiteral bTypeAny1 = + Expression.Literal.Builder lineBldr = 
Expression.Literal.newBuilder(); + Expression.Literal.Struct lineStruct = + Expression.Literal.Struct.newBuilder() + .addFields(pointBldr.build()) + .addFields(pointBldr.build()) + .build(); + Any lineAny = Any.pack(lineBldr.setStruct(lineStruct).build()); + UserDefinedLiteral lineAny1 = ExpressionCreator.userDefinedLiteralAny( - false, URN, "b_type", java.util.Collections.emptyList(), bTypeAny); + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + lineAny); + + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 50)) + .addFields(ExpressionCreator.i32(false, 60)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 70)) + .addFields(ExpressionCreator.i32(false, 80)) + .build(); - io.substrait.expression.Expression.UserDefinedStruct bTypeStruct = + io.substrait.expression.Expression.UserDefinedStruct lineStructLit = io.substrait.expression.Expression.UserDefinedStruct.builder() .nullable(false) - .urn(URN) - .name("b_type") - .addFields(ExpressionCreator.string(false, "struct_b")) - .addFields(ExpressionCreator.bool(false, true)) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) .build(); Rel originalRel = b.project( - input -> List.of(aTypeAny1, aTypeStruct, bTypeAny1, bTypeStruct), + input -> List.of(pointAny1, pointStructLit, lineAny1, lineStructLit), b.remap(4), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + b.namedScan( + List.of("example"), + List.of("a"), + 
List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index f1004261a..0b10a8134 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -50,7 +50,7 @@ public boolean isTypeFromFactory(RelDataType type) { return false; } - private static class InnerType extends RelDataTypeImpl { + private class InnerType extends RelDataTypeImpl { private final boolean nullable; private final String name; @@ -72,7 +72,9 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append(name); + // Include URN in type string to ensure types with same name but different URNs + // are not considered equal by Calcite's type system + sb.append(UserTypeFactory.this.urn).append("::").append(name); } } } From 8e0548805a2466ae96a71a3a53e0cf8aba21dd28 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 20 Nov 2025 11:34:44 -0500 Subject: [PATCH 10/11] tweak: more adjustments --- .../expression/ExpressionRexConverter.java | 42 ++++-- .../isthmus/expression/LiteralConverter.java | 76 +++++++++- .../isthmus/UserDefinedTypeLiteralTest.java | 141 +++++++++++++++++- .../isthmus/utils/UserTypeFactory.java | 49 ++++-- .../isthmus/utils/UserTypeFactoryTest.java | 35 +++++ 5 files changed, 314 insertions(+), 29 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index bf6aeceba..f8488e8af 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -121,11 +121,14 @@ public RexNode visit(Expression.UserDefinedAny expr, Context context) throws Run @Override public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { - // Convert field types to Calcite types for the struct representation java.util.List fieldTypes = - expr.fields().stream() - .map(field -> typeConverter.toCalcite(typeFactory, field.getType())) - .collect(java.util.stream.Collectors.toList()); + new java.util.ArrayList<>(expr.fields().size()); + java.util.List fieldLiterals = new java.util.ArrayList<>(expr.fields().size()); + + for (Expression.Literal field : expr.fields()) { + fieldTypes.add(toStructFieldType(field)); + fieldLiterals.add(toStructFieldLiteral(field, context)); + } // Generate dummy field names (f0, f1, f2, etc.) to satisfy Calcite's ROW type requirements. // Substrait UserDefinedStruct doesn't have field names - just ordered field values. 
@@ -144,14 +147,35 @@ public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws fieldTypes, fieldNames); - java.util.List fieldLiterals = - expr.fields().stream() - .map(field -> (RexLiteral) field.accept(this, context)) - .collect(java.util.stream.Collectors.toList()); - return rexBuilder.makeLiteral(fieldLiterals, customType, false); } + private org.apache.calcite.rel.type.RelDataType toStructFieldType(Expression.Literal field) { + if (field instanceof Expression.UserDefinedAny) { + io.substrait.type.Type.UserDefined userDefinedType = + (io.substrait.type.Type.UserDefined) field.getType(); + return io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType.from( + userDefinedType); + } + return typeConverter.toCalcite(typeFactory, field.getType()); + } + + private RexLiteral toStructFieldLiteral(Expression.Literal field, Context context) { + if (field instanceof Expression.UserDefinedAny) { + Expression.UserDefinedAny userDefinedAny = (Expression.UserDefinedAny) field; + org.apache.calcite.avatica.util.ByteString bytes = + new org.apache.calcite.avatica.util.ByteString(userDefinedAny.value().toByteArray()); + return rexBuilder.makeBinaryLiteral(bytes); + } + + RexNode rexField = field.accept(this, context); + if (!(rexField instanceof RexLiteral)) { + throw new IllegalArgumentException( + "Expected literal when converting UserDefinedStruct field but found " + rexField); + } + return (RexLiteral) rexField; + } + @Override public RexNode visit(Expression.BoolLiteral expr, Context context) throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 910b9682f..fd17b8860 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -1,6 
+1,8 @@ package io.substrait.isthmus.expression; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.TypeConverter; @@ -14,13 +16,16 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.DateString; import org.apache.calcite.util.NlsString; import org.apache.calcite.util.TimeString; @@ -193,7 +198,14 @@ public Expression.Literal convert(RexLiteral literal) { case ROW: { - // Check if this is a SubstraitUserDefinedStructType + @SuppressWarnings("unchecked") + List fieldNodes = (List) literal.getValue(); + List relFields = literal.getType().getFieldList(); + ArrayList convertedFields = new ArrayList<>(fieldNodes.size()); + for (int i = 0; i < fieldNodes.size(); i++) { + convertedFields.add(convertStructField(fieldNodes.get(i), relFields.get(i).getType())); + } + if (literal.getType() instanceof io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedStructType) { @@ -202,19 +214,15 @@ public Expression.Literal convert(RexLiteral literal) { (io.substrait.isthmus.type.SubstraitUserDefinedType .SubstraitUserDefinedStructType) literal.getType(); - List literals = (List) literal.getValue(); return ExpressionCreator.userDefinedLiteralStruct( udtType.isNullable(), udtType.getUrn(), udtType.getName(), udtType.getTypeParameters(), - 
literals.stream().map(this::convert).collect(Collectors.toList())); + convertedFields); } - // Regular struct - List literals = (List) literal.getValue(); - return ExpressionCreator.struct( - n, literals.stream().map(this::convert).collect(Collectors.toList())); + return ExpressionCreator.struct(n, convertedFields); } case ARRAY: @@ -252,4 +260,58 @@ public static byte[] padRightIfNeeded(byte[] value, int length) { System.arraycopy(value, 0, newArray, 0, value.length); return newArray; } + + private Expression.Literal convertStructField(RexNode fieldNode, RelDataType expectedType) { + if (expectedType + instanceof io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType) { + return convertUserDefinedAnyStructField( + fieldNode, + (io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType) + expectedType); + } + + if (!(fieldNode instanceof RexLiteral)) { + throw new UnsupportedOperationException( + "Expected literal struct field but found " + fieldNode); + } + return convert((RexLiteral) fieldNode); + } + + private Expression.Literal convertUserDefinedAnyStructField( + RexNode fieldNode, + io.substrait.isthmus.type.SubstraitUserDefinedType.SubstraitUserDefinedAnyType expectedType) { + if (!(fieldNode instanceof RexLiteral)) { + throw new UnsupportedOperationException( + "Expected literal for UserDefinedAny struct field but found " + fieldNode); + } + + RexLiteral literal = (RexLiteral) fieldNode; + if (literal.isNull()) { + return ExpressionCreator.typedNull( + Type.UserDefined.builder() + .urn(expectedType.getUrn()) + .name(expectedType.getName()) + .typeParameters(expectedType.getTypeParameters()) + .nullable(true) + .build()); + } + + org.apache.calcite.avatica.util.ByteString bytes = + literal.getValueAs(org.apache.calcite.avatica.util.ByteString.class); + if (bytes == null) { + throw new IllegalArgumentException( + "Expected binary literal for UserDefinedAny struct field but value was null"); + } + try { + Any 
anyValue = Any.parseFrom(bytes.getBytes()); + return ExpressionCreator.userDefinedLiteralAny( + expectedType.isNullable(), + expectedType.getUrn(), + expectedType.getName(), + expectedType.getTypeParameters(), + anyValue); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Failed to parse UserDefinedAny literal", e); + } + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java index cd1608a94..5d2fd0054 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import com.google.protobuf.Any; import io.substrait.dsl.SubstraitBuilder; @@ -365,11 +366,63 @@ void userDefinedStructWithNestedStructProtoRoundtrip() { b.namedScan( List.of("example"), List.of("a"), - List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); verifyProtoRoundTrip(originalRel); } + @Test + void userDefinedStructWithNestedAnyCalciteRoundtrip() { + // Mix struct-encoded and Any-encoded fields inside a single UserDefinedStruct literal + Expression.Literal.Builder pointBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setI32(5)) + .addFields(Expression.Literal.newBuilder().setI32(10)) + .build(); + Any pointAny = Any.pack(pointBuilder.setStruct(pointStructProto).build()); + UserDefinedLiteral pointAnyLiteral = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + 
pointAny); + + io.substrait.expression.Expression.UserDefinedStruct pointStructLiteral = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 20)) + .addFields(ExpressionCreator.i32(false, 30)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct line = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(pointAnyLiteral) + .addFields(pointStructLiteral) + .build(); + + Rel originalRel = + b.project( + input -> List.of(line), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "point")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + @Test void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { // Test multiple UserDefinedStruct types with different struct schemas @@ -423,6 +476,92 @@ void multipleUserDefinedStructDifferentStructuresProtoRoundtrip() { verifyProtoRoundTrip(originalRel); } + @Test + void sameUdTypeDifferentEncodingsCalciteRoundtrip() { + // Validate that "line" UDT survives Calcite roundtrip in both Any and Struct encodings + Expression.Literal.Builder lineBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Builder pointBuilder = Expression.Literal.newBuilder(); + Expression.Literal.Struct pointStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(pointBuilder.clear().setI32(5).build()) + .addFields(pointBuilder.clear().setI32(15).build()) + .build(); + Expression.Literal.Struct lineStructProto = + Expression.Literal.Struct.newBuilder() + .addFields(Expression.Literal.newBuilder().setStruct(pointStructProto).build()) + 
.addFields(Expression.Literal.newBuilder().setStruct(pointStructProto).build()) + .build(); + Any lineAnyValue = Any.pack(lineBuilder.setStruct(lineStructProto).build()); + UserDefinedLiteral lineAny = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "line", + java.util.Collections.emptyList(), + lineAnyValue); + + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 1)) + .addFields(ExpressionCreator.i32(false, 2)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 3)) + .addFields(ExpressionCreator.i32(false, 4)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStruct = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel relWithAny = + b.project( + input -> List.of(lineAny), + b.remap(1), + b.namedScan( + List.of("example_any"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + Rel relWithStruct = + b.project( + input -> List.of(lineStruct), + b.remap(1), + b.namedScan( + List.of("example_struct"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + + Rel roundtrippedAny = calciteToSubstrait.apply(substraitToCalcite.convert(relWithAny)); + assertInstanceOf(io.substrait.relation.Project.class, roundtrippedAny); + io.substrait.relation.Project anyProject = (io.substrait.relation.Project) roundtrippedAny; + assertEquals(1, 
anyProject.getExpressions().size()); + assertInstanceOf( + io.substrait.expression.Expression.UserDefinedAny.class, + anyProject.getExpressions().get(0)); + + Rel roundtrippedStruct = calciteToSubstrait.apply(substraitToCalcite.convert(relWithStruct)); + assertInstanceOf(io.substrait.relation.Project.class, roundtrippedStruct); + io.substrait.relation.Project structProject = + (io.substrait.relation.Project) roundtrippedStruct; + assertEquals(1, structProject.getExpressions().size()); + assertInstanceOf( + io.substrait.expression.Expression.UserDefinedStruct.class, + structProject.getExpressions().get(0)); + } + @Test void intermixedUserDefinedAnyAndStructProtoRoundtrip() { // Test intermixing UserDefinedAny and UserDefinedStruct in the same query diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 0b10a8134..d230f466d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -1,5 +1,6 @@ package io.substrait.isthmus.utils; +import io.substrait.isthmus.type.SubstraitUserDefinedType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import org.apache.calcite.rel.type.RelDataType; @@ -35,21 +36,47 @@ public Type createSubstrait(boolean nullable) { return TypeCreator.of(nullable).userDefined(urn, name); } + /** + * Test-specific variant of the core implementation that treats Calcite-copied types as + * equivalent, even when they are not the same Java instance. Calcite often clones UDTs during + * planning, so reference equality alone would fail in tests. 
+ */ public boolean isTypeFromFactory(RelDataType type) { - // Use value-based comparison instead of reference equality to handle - // cases where the same type is created by different factory instances - if (type == N || type == R) { - return true; - } - // Check if this is a type with the same name and SqlTypeName.OTHER - if (type != null - && type.getSqlTypeName() == SqlTypeName.OTHER - && type.toString().equals(this.name)) { - return true; + return matchesSubstraitType(type) || type == N || type == R || matchesCalciteAlias(type); + } + + /** + * Detects Substrait-backed Calcite types by interrogating their metadata. + * + *

    If Calcite preserves the original {@link SubstraitUserDefinedType}, the urn/name are both + * available directly. Otherwise, Calcite may create an anonymous {@link RelDataTypeImpl} copy, + * exposing only its alias string. In that case we fall back to comparing the formatted alias (see + * {@link InnerType#generateTypeString}). + */ + private boolean matchesSubstraitType(RelDataType type) { + if (type instanceof SubstraitUserDefinedType) { + SubstraitUserDefinedType udt = (SubstraitUserDefinedType) type; + return this.urn.equals(udt.getUrn()) && this.name.equals(udt.getName()); } return false; } + /** + * Calcite may copy a user-defined type into an anonymous {@link RelDataTypeImpl} where the only + * identifier left is its alias string. This helper captures the "find by alias" fallback so it’s + * clear we’re matching against the formatted urn::name when the rich metadata is not + * available. + */ + private boolean matchesCalciteAlias(RelDataType type) { + return type != null + && (type.getSqlTypeName() == SqlTypeName.OTHER || type.getSqlTypeName() == SqlTypeName.ROW) + && type.toString().equals(calciteDisplayName()); + } + + private String calciteDisplayName() { + return String.format("%s::%s", this.urn, this.name); + } + private class InnerType extends RelDataTypeImpl { private final boolean nullable; private final String name; @@ -72,8 +99,6 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - // Include URN in type string to ensure types with same name but different URNs - // are not considered equal by Calcite's type system sb.append(UserTypeFactory.this.urn).append("::").append(name); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java new file mode 100644 index 000000000..df1c2dcd0 --- /dev/null +++ 
b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactoryTest.java @@ -0,0 +1,35 @@ +package io.substrait.isthmus.utils; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.isthmus.type.SubstraitUserDefinedType; +import io.substrait.type.Type; +import org.apache.calcite.rel.type.RelDataType; +import org.junit.jupiter.api.Test; + +class UserTypeFactoryTest { + + private static final String URN = "extension:io.substrait:test"; + private static final String NAME = "custom_type"; + + @Test + void detectsSubstraitUserDefinedType() { + UserTypeFactory factory = new UserTypeFactory(URN, NAME); + RelDataType substraitType = + SubstraitUserDefinedType.from( + Type.UserDefined.builder().nullable(true).urn(URN).name(NAME).build()); + + assertTrue(factory.isTypeFromFactory(substraitType)); + } + + @Test + void rejectsDifferentUrnOrName() { + UserTypeFactory factory = new UserTypeFactory(URN, NAME); + RelDataType differentType = + SubstraitUserDefinedType.from( + Type.UserDefined.builder().nullable(true).urn(URN).name("other").build()); + + assertFalse(factory.isTypeFromFactory(differentType)); + } +} From 3c3b63f8bffd68d391f170385055ddf2bc5dfaff Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 20 Nov 2025 14:41:32 -0500 Subject: [PATCH 11/11] tweak: more adjustments --- .../expression/ExpressionRexConverter.java | 55 ++++++++++++++++--- .../type/SubstraitUserDefinedType.java | 39 ++++++++----- .../isthmus/UserDefinedTypeLiteralTest.java | 44 +++++++++++++++ .../type/SubstraitUserDefinedTypeTest.java | 42 ++++++++++++++ 4 files changed, 157 insertions(+), 23 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index f8488e8af..9d7ea630b 
100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -121,18 +121,19 @@ public RexNode visit(Expression.UserDefinedAny expr, Context context) throws Run @Override public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { + return toUserDefinedStructLiteral(expr, context); + } + + private RexLiteral toUserDefinedStructLiteral(Expression.UserDefinedStruct expr, Context context) { java.util.List fieldTypes = new java.util.ArrayList<>(expr.fields().size()); java.util.List fieldLiterals = new java.util.ArrayList<>(expr.fields().size()); for (Expression.Literal field : expr.fields()) { fieldTypes.add(toStructFieldType(field)); - fieldLiterals.add(toStructFieldLiteral(field, context)); + fieldLiterals.add(literalToRexLiteral(field, context)); } - // Generate dummy field names (f0, f1, f2, etc.) to satisfy Calcite's ROW type requirements. - // Substrait UserDefinedStruct doesn't have field names - just ordered field values. - // These synthetic names are discarded during conversion back to Substrait. 
java.util.List fieldNames = java.util.stream.IntStream.range(0, expr.fields().size()) .mapToObj(i -> "f" + i) @@ -147,7 +148,7 @@ public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws fieldTypes, fieldNames); - return rexBuilder.makeLiteral(fieldLiterals, customType, false); + return (RexLiteral) rexBuilder.makeLiteral(fieldLiterals, customType, false); } private org.apache.calcite.rel.type.RelDataType toStructFieldType(Expression.Literal field) { @@ -161,14 +162,52 @@ private org.apache.calcite.rel.type.RelDataType toStructFieldType(Expression.Lit } private RexLiteral toStructFieldLiteral(Expression.Literal field, Context context) { - if (field instanceof Expression.UserDefinedAny) { - Expression.UserDefinedAny userDefinedAny = (Expression.UserDefinedAny) field; + return literalToRexLiteral(field, context); + } + + private RexLiteral literalToRexLiteral(Expression.Literal literal, Context context) { + if (literal instanceof Expression.UserDefinedAny) { + Expression.UserDefinedAny userDefinedAny = (Expression.UserDefinedAny) literal; org.apache.calcite.avatica.util.ByteString bytes = new org.apache.calcite.avatica.util.ByteString(userDefinedAny.value().toByteArray()); return rexBuilder.makeBinaryLiteral(bytes); } - RexNode rexField = field.accept(this, context); + if (literal instanceof Expression.UserDefinedStruct) { + return toUserDefinedStructLiteral((Expression.UserDefinedStruct) literal, context); + } + + if (literal instanceof Expression.StructLiteral) { + java.util.List fieldValues = + new java.util.ArrayList<>(((Expression.StructLiteral) literal).fields().size()); + for (Expression.Literal child : ((Expression.StructLiteral) literal).fields()) { + fieldValues.add(literalToRexLiteral(child, context)); + } + return (RexLiteral) + rexBuilder.makeLiteral( + fieldValues, typeConverter.toCalcite(typeFactory, literal.getType()), false); + } + + if (literal instanceof Expression.ListLiteral) { + java.util.List elements = + new 
java.util.ArrayList<>(((Expression.ListLiteral) literal).values().size()); + for (Expression.Literal child : ((Expression.ListLiteral) literal).values()) { + elements.add(literalToRexLiteral(child, context)); + } + return (RexLiteral) + rexBuilder.makeLiteral( + elements, typeConverter.toCalcite(typeFactory, literal.getType()), false); + } + + if (literal instanceof Expression.EmptyListLiteral) { + return (RexLiteral) + rexBuilder.makeLiteral( + java.util.Collections.emptyList(), + typeConverter.toCalcite(typeFactory, literal.getType()), + false); + } + + RexNode rexField = literal.accept(this, context); if (!(rexField instanceof RexLiteral)) { throw new IllegalArgumentException( "Expected literal when converting UserDefinedStruct field but found " + rexField); diff --git a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java index 2ce33eb5d..aa155a476 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java +++ b/isthmus/src/main/java/io/substrait/isthmus/type/SubstraitUserDefinedType.java @@ -1,7 +1,9 @@ package io.substrait.isthmus.type; +import com.google.protobuf.TextFormat; import io.substrait.type.Type; import java.util.List; +import java.util.stream.Collectors; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelDataTypeFieldImpl; @@ -85,14 +87,29 @@ public static SubstraitUserDefinedType from(io.substrait.type.Type.UserDefined t @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - // Include URN in type string to ensure types with same name but different URNs - // are not considered equal by Calcite's type system + appendDigest(sb); + } + + protected void appendDigest(StringBuilder sb) { sb.append(urn).append("::").append(name); - if (!typeParameters.isEmpty()) { - sb.append("<"); - 
sb.append(String.join(", ", java.util.Collections.nCopies(typeParameters.size(), "_"))); - sb.append(">"); + appendTypeParameters(sb, typeParameters); + } + + private static void appendTypeParameters( + StringBuilder sb, java.util.List parameters) { + if (parameters.isEmpty()) { + return; } + sb.append("<"); + sb.append( + parameters.stream() + .map(SubstraitUserDefinedType::formatParameter) + .collect(Collectors.joining(","))); + sb.append(">"); + } + + private static String formatParameter(io.substrait.proto.Type.Parameter parameter) { + return TextFormat.shortDebugString(parameter); } /** @@ -208,15 +225,7 @@ public SqlTypeName getSqlTypeName() { @Override protected void generateTypeString(StringBuilder sb, boolean withDetail) { - // Include URN in type string to ensure types with same name but different URNs - // are not considered equal by Calcite's type system - sb.append(getUrn()).append("::").append(getName()); - if (!getTypeParameters().isEmpty()) { - sb.append("<"); - sb.append( - String.join(", ", java.util.Collections.nCopies(getTypeParameters().size(), "_"))); - sb.append(">"); - } + appendDigest(sb); if (withDetail && fieldNames != null) { sb.append("("); sb.append( diff --git a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java index 5d2fd0054..adf96ec7a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UserDefinedTypeLiteralTest.java @@ -256,6 +256,51 @@ void singleUserDefinedStructCalciteRoundtrip() { } @Test + void nestedUserDefinedStructCalciteRoundtrip() { + io.substrait.expression.Expression.UserDefinedStruct startPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 5)) + .addFields(ExpressionCreator.i32(false, 15))
+ .build(); + + io.substrait.expression.Expression.UserDefinedStruct endPoint = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("point") + .addFields(ExpressionCreator.i32(false, 25)) + .addFields(ExpressionCreator.i32(false, 35)) + .build(); + + io.substrait.expression.Expression.UserDefinedStruct lineStructLit = + io.substrait.expression.Expression.UserDefinedStruct.builder() + .nullable(false) + .urn(DefaultExtensionCatalog.EXTENSION_TYPES) + .name("line") + .addFields(startPoint) + .addFields(endPoint) + .build(); + + Rel originalRel = + b.project( + input -> List.of(lineStructLit), + b.remap(1), + b.namedScan( + List.of("example"), + List.of("a"), + List.of(N.userDefined(DefaultExtensionCatalog.EXTENSION_TYPES, "line")))); + + assertCalciteRoundtrip( + originalRel, + substraitToCalcite, + calciteToSubstrait, + DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + @Test void multipleDifferentUserDefinedAnyTypesCalciteRoundtrip() { // Test that multiple UserDefinedAny literals with different types can roundtrip through Calcite Expression.Literal.Builder bldr1 = Expression.Literal.newBuilder(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java new file mode 100644 index 000000000..105d30091 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/type/SubstraitUserDefinedTypeTest.java @@ -0,0 +1,42 @@ +package io.substrait.isthmus.type; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import io.substrait.proto.Type; +import java.util.List; +import org.junit.jupiter.api.Test; + +class SubstraitUserDefinedTypeTest { + private static final String URN = "extension:io.substrait:test"; + private static final String NAME = "custom"; + + @Test + void
differentTypeParametersProduceDifferentDigests() { + Type.Parameter integerParam = Type.Parameter.newBuilder().setInteger(1).build(); + Type.Parameter enumParam = Type.Parameter.newBuilder().setEnum("value").build(); + + SubstraitUserDefinedType typeWithInteger = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), false); + SubstraitUserDefinedType typeWithEnum = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(enumParam), false); + + assertNotEquals(typeWithInteger, typeWithEnum); + assertNotEquals(typeWithInteger.toString(), typeWithEnum.toString()); + } + + @Test + void sameParametersRemainEqual() { + Type.Parameter integerParam = Type.Parameter.newBuilder().setInteger(7).build(); + SubstraitUserDefinedType left = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), true); + SubstraitUserDefinedType right = + new SubstraitUserDefinedType.SubstraitUserDefinedAnyType( + URN, NAME, List.of(integerParam), true); + + assertEquals(left, right); + } +}