diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..8307785a1 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -2,6 +2,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.common.base.Charsets; +import com.google.common.io.Resources; import io.substrait.dsl.SubstraitBuilder; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; @@ -10,24 +12,38 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.TypeCreator; +import java.io.IOException; public abstract class TestBase { - protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection = - DefaultExtensionCatalog.DEFAULT_COLLECTION; + protected static final TypeCreator R = TypeCreator.REQUIRED; + protected static final TypeCreator N = TypeCreator.NULLABLE; - protected TypeCreator R = TypeCreator.REQUIRED; - protected TypeCreator N = TypeCreator.NULLABLE; + protected final SimpleExtension.ExtensionCollection extensions; - protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection); protected ExtensionCollector functionCollector = new ExtensionCollector(); protected RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); - protected ProtoRelConverter protoRelConverter = - new ProtoRelConverter(functionCollector, defaultExtensionCollection); + + protected SubstraitBuilder sb; + protected ProtoRelConverter protoRelConverter; + + protected TestBase() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + protected TestBase(SimpleExtension.ExtensionCollection extensions) { + this.extensions = extensions; + this.sb = new SubstraitBuilder(extensions); + this.protoRelConverter = new ProtoRelConverter(functionCollector, extensions); + } protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + public static String asString(String resource) throws IOException { + return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); + } } diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java index 46fc5041e..a25a9009c 100644 --- a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java +++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java @@ -3,19 +3,17 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.TestBase; import io.substrait.extension.ImmutableSimpleExtension; import io.substrait.extension.SimpleExtension; import io.substrait.function.ParameterizedType; -import io.substrait.type.TypeCreator; import java.util.Collections; import java.util.List; import java.util.Optional; import org.junit.jupiter.api.Test; /** Tests for variadic parameter consistency validation in Expression. */ -class VariadicParameterConsistencyTest { - - private static final TypeCreator R = TypeCreator.of(false); +class VariadicParameterConsistencyTest extends TestBase { /** * Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index b84ef8bd2..9b6ae95a8 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -88,7 +88,7 @@ private static ImmutableExpressionReference getFieldReferenceExpression() { private static ImmutableExpressionReference getScalarFunctionExpression() { Expression.ScalarFunctionInvocation scalarFunctionInvocation = - new SubstraitBuilder(defaultExtensionCollection) + new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION) .scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC_DECIMAL, "add:dec_dec", @@ -111,8 +111,9 @@ private static ImmutableAggregateFunctionReference getAggregateFunctionReference .function( AggregateFunctionInvocation.builder() .arguments(Collections.emptyList()) - .declaration(defaultExtensionCollection.aggregateFunctions().get(0)) - .outputType(TypeCreator.of(false).I64) + .declaration( + DefaultExtensionCatalog.DEFAULT_COLLECTION.aggregateFunctions().get(0)) + .outputType(R.I64) .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) .invocation(Expression.AggregationInvocation.ALL) .build()) diff --git a/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java index c217d078e..dec329ae0 100644 --- a/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java +++ b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java @@ -2,7 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import io.substrait.dsl.SubstraitBuilder; +import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.Expression.SortDirection; import io.substrait.expression.ExpressionCreator; @@ -48,9 +48,7 @@ import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; import org.junit.jupiter.api.Test; -class AdvancedExtensionRelProtoConversionTest { - final SubstraitBuilder builder = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION); - +class AdvancedExtensionRelProtoConversionTest extends TestBase { final StringHolder enhanced = new StringHolder("ENHANCED"); final StringHolder optimized = new StringHolder("OPTIMIZED"); final AdvancedExtension extension = @@ -168,7 +166,7 @@ void testAggregateRelConversionRoundtrip() throws Exception { .initialSchema( NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build()) .build()) - .addMeasures(builder.countStar()) + .addMeasures(sb.countStar()) .extension(extension) .build(); @@ -191,7 +189,7 @@ void testSortRelConversionRoundtrip() throws Exception { .addSortFields( SortField.builder() .direction(SortDirection.ASC_NULLS_FIRST) - .expr(builder.fieldReference(scan, 0)) + .expr(sb.fieldReference(scan, 0)) .build()) .extension(extension) .build(); @@ -214,8 +212,7 @@ void testJoinRelConversionRoundtrip() throws Exception { .left(scan) .right(scan) .joinType(JoinType.INNER) - .condition( - builder.equal(builder.fieldReference(scan, 0), builder.fieldReference(scan, 0))) + .condition(sb.equal(sb.fieldReference(scan, 0), sb.fieldReference(scan, 0))) .extension(extension) .build(); @@ -235,7 +232,7 @@ void testProjectRelConversionRoundtrip() throws Exception { final Project rel = Project.builder() .input(scan) - .addExpressions(builder.fieldReference(scan, 0)) + .addExpressions(sb.fieldReference(scan, 0)) .extension(extension) .build(); @@ -384,12 +381,12 @@ void testNamedUpdateRelConversionRoundtrip() throws Exception { .addNames("CUSTOMER") .tableSchema(schema) .condition( - builder.equal( + sb.equal( FieldReference.builder() .addSegments(FieldReference.StructField.of(0)) .type(TypeCreator.REQUIRED.BOOLEAN) .build(), - builder.bool(true))) + sb.bool(true))) .addTransformations( TransformExpression.builder() .columnTarget(0) diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index abb4008d5..4b3b94a80 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -3,13 +3,13 @@ import static io.substrait.type.TypeCreator.REQUIRED; import static org.junit.jupiter.api.Assertions.assertEquals; -import io.substrait.dsl.SubstraitBuilder; +import io.substrait.TestBase; import io.substrait.plan.Plan; import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; -import java.io.InputStream; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -24,24 +24,29 @@ *
  • Roundtrip between POJO and Proto * */ -class TypeExtensionTest { - - static final TypeCreator R = TypeCreator.of(false); +class TypeExtensionTest extends TestBase { static final String URN = "extension:test:custom_extensions"; - final SimpleExtension.ExtensionCollection extensionCollection; + static final SimpleExtension.ExtensionCollection CUSTOM_EXTENSION; - { - String path = "/extensions/custom_extensions.yaml"; - InputStream inputStream = this.getClass().getResourceAsStream(path); - extensionCollection = SimpleExtension.load(path, inputStream); + static { + try { + String customExtensionStr = asString("extensions/custom_extensions.yaml"); + CUSTOM_EXTENSION = SimpleExtension.load(URN, customExtensionStr); + } catch (IOException e) { + throw new UncheckedIOException(e); + } } - final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); - Type customType1 = b.userDefinedType(URN, "customType1"); - Type customType2 = b.userDefinedType(URN, "customType2"); + Type customType1 = sb.userDefinedType(URN, "customType1"); + Type customType2 = sb.userDefinedType(URN, "customType2"); final PlanProtoConverter planProtoConverter = new PlanProtoConverter(); - final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection); + final ProtoPlanConverter protoPlanConverter; + + TypeExtensionTest() { + super(CUSTOM_EXTENSION); + this.protoPlanConverter = new ProtoPlanConverter(extensions); + } @Test void roundtripCustomType() { @@ -49,26 +54,26 @@ void roundtripCustomType() { List tableName = Stream.of("example").collect(Collectors.toList()); List columnNames = Stream.of("custom_type_column", "i64_column").collect(Collectors.toList()); - List types = Stream.of(customType1, R.I64).collect(Collectors.toList()); + List types = Stream.of(customType1, R.I64).collect(Collectors.toList()); // SELECT custom_type_column, scalar1(custom_type_column), scalar2(i64_column) // FROM example Plan plan = - b.plan( - b.root( - b.project( + sb.plan( + sb.root( + sb.project( input -> Stream.of( - b.fieldReference(input, 0), - b.scalarFn( + sb.fieldReference(input, 0), + sb.scalarFn( URN, "scalar1:u!customType1", R.I64, - b.fieldReference(input, 0)), - b.scalarFn( - URN, "scalar2:i64", customType2, b.fieldReference(input, 1))) + sb.fieldReference(input, 0)), + sb.scalarFn( + URN, "scalar2:i64", customType2, sb.fieldReference(input, 1))) .collect(Collectors.toList()), - b.namedScan(tableName, columnNames, types)))); + sb.namedScan(tableName, columnNames, types)))); io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); Plan planReturned = protoPlanConverter.from(protoPlan); @@ -80,19 +85,21 @@ void roundtripNumberedAnyTypes() { List tableName = Stream.of("example").collect(Collectors.toList()); List columnNames = Stream.of("array_i64_type_column", "array_i64_column").collect(Collectors.toList()); - List types = - Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList()); + List types = Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList()); Plan plan = - b.plan( - b.root( - b.project( + sb.plan( + sb.root( + sb.project( input -> Stream.of( - b.scalarFn( - URN, "array_index:list_i64", R.I64, b.fieldReference(input, 0))) + sb.scalarFn( + URN, + "array_index:list_i64", + R.I64, + sb.fieldReference(input, 0))) .collect(Collectors.toList()), - b.namedScan(tableName, columnNames, types)))); + sb.namedScan(tableName, columnNames, types)))); io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index 3876fd2ad..d84f06234 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -25,7 +25,7 @@ class ProtoRelConverterTest extends TestBase { final NamedScan commonTable = - b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + sb.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); /** * Verify default behaviour of {@link ProtoRelConverter} in the presence of {@link @@ -121,9 +121,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter()) final Rel relFromProto = new ProtoRelConverter( - functionCollector, - defaultExtensionCollection, - new StringHolderHandlingProtoExtensionConverter()) + functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter()) .from(protoRel); assertEquals(rel, relFromProto); @@ -140,9 +138,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter()) final Rel relFromProto = new ProtoRelConverter( - functionCollector, - defaultExtensionCollection, - new StringHolderHandlingProtoExtensionConverter()) + functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter()) .from(protoRel); assertEquals(rel, relFromProto); diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index 01276fd4a..6d32e9301 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -13,7 +13,6 @@ import io.substrait.relation.RelProtoConverter; import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; -import io.substrait.type.TypeCreator; import java.math.BigDecimal; import java.util.Arrays; import java.util.Collections; @@ -33,8 +32,6 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio ExtensionCollector functionCollector = new ExtensionCollector(); RelProtoConverter to = new RelProtoConverter(functionCollector); - io.substrait.extension.SimpleExtension.ExtensionCollection extensions = - defaultExtensionCollection; ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions); io.substrait.relation.ImmutableMeasure measure = @@ -43,7 +40,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio AggregateFunctionInvocation.builder() .arguments(Collections.emptyList()) .declaration(extensions.aggregateFunctions().get(0)) - .outputType(TypeCreator.of(false).I64) + .outputType(R.I64) .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) .invocation(invocation) .options( @@ -56,7 +53,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio Arrays.asList( Expression.SortField.builder() // SORT BY decimal - .expr(b.fieldReference(input, 0)) + .expr(sb.fieldReference(input, 0)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build())) .build()) diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index afd8a494d..bf8639380 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -18,11 +18,11 @@ class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @Test void consistentPartitionWindowRoundtripSingle() { SimpleExtension.WindowFunctionVariant windowFunctionDeclaration = - defaultExtensionCollection.getWindowFunction( + extensions.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); Rel input = - b.namedScan( + sb.namedScan( Arrays.asList("test"), Arrays.asList("a", "b", "c"), Arrays.asList(R.I64, R.I16, R.I32)); @@ -34,7 +34,7 @@ void consistentPartitionWindowRoundtripSingle() { ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() .declaration(windowFunctionDeclaration) // lead(a) - .arguments(Arrays.asList(b.fieldReference(input, 0))) + .arguments(Arrays.asList(sb.fieldReference(input, 0))) .options( Arrays.asList( FunctionOption.builder() @@ -49,12 +49,12 @@ void consistentPartitionWindowRoundtripSingle() { .boundsType(Expression.WindowBoundsType.RANGE) .build())) // PARTITION BY b - .partitionExpressions(Arrays.asList(b.fieldReference(input, 1))) + .partitionExpressions(Arrays.asList(sb.fieldReference(input, 1))) .sorts( Arrays.asList( Expression.SortField.builder() // SORT BY c - .expr(b.fieldReference(input, 2)) + .expr(sb.fieldReference(input, 2)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build())) .build(); @@ -71,15 +71,15 @@ void consistentPartitionWindowRoundtripSingle() { @Test void consistentPartitionWindowRoundtripMulti() { SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration = - defaultExtensionCollection.getWindowFunction( + extensions.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration = - defaultExtensionCollection.getWindowFunction( + extensions.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); Rel input = - b.namedScan( + sb.namedScan( Arrays.asList("test"), Arrays.asList("a", "b", "c"), Arrays.asList(R.I64, R.I16, R.I32)); @@ -91,7 +91,7 @@ void consistentPartitionWindowRoundtripMulti() { ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() .declaration(windowFunctionLeadDeclaration) // lead(a) - .arguments(Arrays.asList(b.fieldReference(input, 0))) + .arguments(Arrays.asList(sb.fieldReference(input, 0))) .options( Arrays.asList( FunctionOption.builder() @@ -108,7 +108,7 @@ void consistentPartitionWindowRoundtripMulti() { ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() .declaration(windowFunctionLagDeclaration) // lag(a) - .arguments(Arrays.asList(b.fieldReference(input, 0))) + .arguments(Arrays.asList(sb.fieldReference(input, 0))) .options( Arrays.asList( FunctionOption.builder() @@ -123,12 +123,12 @@ void consistentPartitionWindowRoundtripMulti() { .boundsType(Expression.WindowBoundsType.RANGE) .build())) // PARTITION BY b - .partitionExpressions(Arrays.asList(b.fieldReference(input, 1))) + .partitionExpressions(Arrays.asList(sb.fieldReference(input, 1))) .sorts( Arrays.asList( Expression.SortField.builder() // SORT BY c - .expr(b.fieldReference(input, 2)) + .expr(sb.fieldReference(input, 2)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build())) .build(); diff --git a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java index 82a816db0..5eb0c3100 100644 --- a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java @@ -44,7 +44,7 @@ void create() { @Test void alter() { ProtoRelConverter protoRelConverter = - new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection); + new StringHolderHandlingProtoRelConverter(functionCollector, extensions); StringHolder detail = new StringHolder("DETAIL"); diff --git a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java index 95b06b8fc..36e757289 100644 --- a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java @@ -17,7 +17,7 @@ class ExchangeRelRoundtripTest extends TestBase { final Rel baseTable = - b.namedScan( + sb.namedScan( Collections.singletonList("exchange_test_table"), Arrays.asList("id", "amount", "name", "status"), Arrays.asList(R.I64, R.FP64, R.STRING, R.BOOLEAN)); @@ -42,7 +42,7 @@ void scatterExchange() { Rel exchange = ScatterExchange.builder() .input(baseTable) - .addFields(b.fieldReference(baseTable, 0)) + .addFields(sb.fieldReference(baseTable, 0)) .partitionCount(1) .build(); @@ -55,7 +55,7 @@ void singleBucketExchange() { SingleBucketExchange.builder() .input(baseTable) .partitionCount(1) - .expression(b.fieldReference(baseTable, 0)) + .expression(sb.fieldReference(baseTable, 0)) .build(); verifyRoundTrip(exchange); @@ -66,7 +66,7 @@ void multiBucketExchange() { Rel exchange = MultiBucketExchange.builder() .input(baseTable) - .expression(b.fieldReference(baseTable, 0)) + .expression(sb.fieldReference(baseTable, 0)) .constrainedToCount(true) .partitionCount(1) .build(); diff --git a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java index 127ac6c00..ef5e60a41 100644 --- a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java @@ -13,20 +13,20 @@ class ExpandRelRoundtripTest extends TestBase { final Rel input = - b.namedScan( + sb.namedScan( Stream.of("a_table").collect(Collectors.toList()), Stream.of("column1", "column2").collect(Collectors.toList()), Stream.of(R.I64, R.I64).collect(Collectors.toList())); private Expand.ExpandField getConsistentField(int index) { - return Expand.ConsistentField.builder().expression(b.fieldReference(input, index)).build(); + return Expand.ConsistentField.builder().expression(sb.fieldReference(input, index)).build(); } private Expand.ExpandField getSwitchingField(List indexes) { return Expand.SwitchingField.builder() .addAllDuplicates( indexes.stream() - .map(index -> b.fieldReference(input, index)) + .map(index -> sb.fieldReference(input, index)) .collect(Collectors.toList())) .build(); } @@ -35,7 +35,7 @@ private Expand.ExpandField getSwitchingField(List indexes) { void expandConsistent() { Rel rel = Expand.builder() - .from(b.expand(__ -> Collections.emptyList(), input)) + .from(sb.expand(__ -> Collections.emptyList(), input)) .hint( Hint.builder() .alias("alias1") @@ -52,7 +52,7 @@ void expandConsistent() { void expandSwitching() { Rel rel = Expand.builder() - .from(b.expand(__ -> Collections.emptyList(), input)) + .from(sb.expand(__ -> Collections.emptyList(), input)) .hint(Hint.builder().addAllOutputNames(Arrays.asList("name1", "name2")).build()) .fields( Stream.of( diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index e19f2ed7d..49895b01d 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -47,10 +47,10 @@ class ExtensionRoundtripTest extends TestBase { final ProtoRelConverter protoRelConverter = - new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection); + new StringHolderHandlingProtoRelConverter(functionCollector, extensions); final Rel commonTable = - b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + sb.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); final AdvancedExtension commonExtension = AdvancedExtension.builder() @@ -105,7 +105,7 @@ void namedScan() { Rel rel = NamedScan.builder() .from( - b.namedScan( + sb.namedScan( Collections.emptyList(), Collections.emptyList(), Collections.emptyList())) .commonExtension(commonExtension) .extension(relExtension) @@ -123,7 +123,7 @@ void extensionTable() { void filter() { Rel rel = Filter.builder() - .from(b.filter(__ -> b.bool(true), commonTable)) + .from(sb.filter(__ -> sb.bool(true), commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -134,7 +134,7 @@ void filter() { void fetch() { Rel rel = Fetch.builder() - .from(b.fetch(1, 2, commonTable)) + .from(sb.fetch(1, 2, commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -145,7 +145,7 @@ void fetch() { void aggregate() { Rel rel = Aggregate.builder() - .from(b.aggregate(b::grouping, __ -> Collections.emptyList(), commonTable)) + .from(sb.aggregate(sb::grouping, __ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -156,7 +156,7 @@ void aggregate() { void sort() { Rel rel = Sort.builder() - .from(b.sort(__ -> Collections.emptyList(), commonTable)) + .from(sb.sort(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -167,7 +167,7 @@ void sort() { void join() { Rel rel = Join.builder() - .from(b.innerJoin(__ -> b.bool(true), commonTable, commonTable)) + .from(sb.innerJoin(__ -> sb.bool(true), commonTable, commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -182,7 +182,7 @@ void hashJoin() { Rel relWithoutKeys = HashJoin.builder() .from( - b.hashJoin( + sb.hashJoin( leftEmptyKeys, rightEmptyKeys, HashJoin.JoinType.INNER, @@ -202,7 +202,7 @@ void mergeJoin() { Rel relWithoutKeys = MergeJoin.builder() .from( - b.mergeJoin( + sb.mergeJoin( leftEmptyKeys, rightEmptyKeys, MergeJoin.JoinType.INNER, @@ -219,8 +219,8 @@ void nestedLoopJoin() { Rel rel = NestedLoopJoin.builder() .from( - b.nestedLoopJoin( - __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable)) + sb.nestedLoopJoin( + __ -> sb.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -231,7 +231,7 @@ void nestedLoopJoin() { void project() { Rel rel = Project.builder() - .from(b.project(__ -> Collections.emptyList(), commonTable)) + .from(sb.project(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -242,7 +242,7 @@ void project() { void expand() { Rel rel = Expand.builder() - .from(b.expand(__ -> Collections.emptyList(), commonTable)) + .from(sb.expand(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) .build(); verifyRoundTrip(rel); @@ -252,7 +252,7 @@ void expand() { void set() { Rel rel = Set.builder() - .from(b.set(Set.SetOp.UNION_ALL, commonTable)) + .from(sb.set(Set.SetOp.UNION_ALL, commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -284,7 +284,7 @@ void extensionLeafRel() { void cross() { Rel rel = Cross.builder() - .from(b.cross(commonTable, commonTable)) + .from(sb.cross(commonTable, commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -297,13 +297,13 @@ class ExtensionThroughExpression { // Check that custom extensions in these relations can be handled. Rel baseTable = - b.namedScan( + sb.namedScan( Stream.of("test_table").collect(Collectors.toList()), Stream.of("test_column").collect(Collectors.toList()), Stream.of(TypeCreator.REQUIRED.I64).collect(Collectors.toList())); Rel relWithEnhancement = Project.builder() - .from(b.project(input -> Collections.emptyList(), baseTable)) + .from(sb.project(input -> Collections.emptyList(), baseTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); @@ -311,7 +311,7 @@ class ExtensionThroughExpression { @Test void scalarSubquery() { Project rel = - b.project( + sb.project( input -> Stream.of( Expression.ScalarSubquery.builder() @@ -327,7 +327,7 @@ void scalarSubquery() { @Test void inPredicate() { Project rel = - b.project( + sb.project( input -> Stream.of( Expression.InPredicate.builder() @@ -342,7 +342,7 @@ void inPredicate() { @Test void setPredicate() { Project rel = - b.project( + sb.project( input -> Stream.of( Expression.SetPredicate.builder() diff --git a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java index a04f0bad4..4e00797c2 100644 --- a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java @@ -17,7 +17,7 @@ class FieldReferenceRoundtripTest extends TestBase { final Rel baseTable = - b.namedScan( + sb.namedScan( Collections.singletonList("test_table"), Arrays.asList("id", "amount", "name", "nested_struct"), Arrays.asList( @@ -30,7 +30,7 @@ class FieldReferenceRoundtripTest extends TestBase { void simpleStructFieldReference() { // Test simple root struct field reference via projection Rel projection = - Project.builder().input(baseTable).addExpressions(b.fieldReference(baseTable, 0)).build(); + Project.builder().input(baseTable).addExpressions(sb.fieldReference(baseTable, 0)).build(); verifyRoundTrip(projection); } @@ -42,9 +42,9 @@ void multipleFieldReferences() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 1), - b.fieldReference(baseTable, 2)) + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 1), + sb.fieldReference(baseTable, 2)) .build(); verifyRoundTrip(projection); @@ -53,7 +53,8 @@ void multipleFieldReferences() { @Test void fieldReferenceInFilter() { // Test field reference in filter condition - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + Expression condition = + sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -63,7 +64,7 @@ void fieldReferenceInFilter() { @Test void fieldReferenceInComplexExpression() { // Test field reference as part of arithmetic expression - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + Expression add = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)); Rel projection = Project.builder().input(baseTable).addExpressions(add).build(); @@ -76,13 +77,13 @@ void fieldReferenceInNestedProjection() { Rel firstProjection = Project.builder() .input(baseTable) - .addExpressions(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 2)) + .addExpressions(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 2)) .build(); Rel secondProjection = Project.builder() .input(firstProjection) - .addExpressions(b.fieldReference(firstProjection, 1)) + .addExpressions(sb.fieldReference(firstProjection, 1)) .build(); verifyRoundTrip(secondProjection); @@ -95,10 +96,10 @@ void fieldReferenceAllFields() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 1), - b.fieldReference(baseTable, 2), - b.fieldReference(baseTable, 3)) + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 1), + sb.fieldReference(baseTable, 2), + sb.fieldReference(baseTable, 3)) .build(); verifyRoundTrip(projection); @@ -108,9 +109,9 @@ void fieldReferenceAllFields() { void fieldReferenceWithBooleanLogic() { // Test field references in boolean expressions Expression condition = - b.and( - b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)), - b.equal(b.fieldReference(baseTable, 2), b.str("test"))); + sb.and( + sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)), + sb.equal(sb.fieldReference(baseTable, 2), sb.str("test"))); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -120,8 +121,8 @@ void fieldReferenceWithBooleanLogic() { @Test void fieldReferenceInMultipleArithmetic() { // Test multiple field references in arithmetic - Expression add = b.add(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1)); - Expression multiply = b.multiply(add, b.fieldReference(baseTable, 1)); + Expression add = sb.add(sb.fieldReference(baseTable, 1), sb.fieldReference(baseTable, 1)); + Expression multiply = sb.multiply(add, sb.fieldReference(baseTable, 1)); Rel projection = Project.builder().input(baseTable).addExpressions(multiply).build(); @@ -135,9 +136,9 @@ void fieldReferenceReordering() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 3), - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 2)) + sb.fieldReference(baseTable, 3), + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 2)) .build(); verifyRoundTrip(projection); @@ -150,9 +151,9 @@ void sameFieldReferencedMultipleTimes() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 0)) + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 0)) .build(); verifyRoundTrip(projection); diff --git a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java index 4ee77133f..95fce4c1c 100644 --- a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java @@ -11,7 +11,7 @@ class FilterRelRoundtripTest extends TestBase { final Rel baseTable = - b.namedScan( + sb.namedScan( Collections.singletonList("test_table"), Arrays.asList("id", "amount", "name", "status"), Arrays.asList(R.I64, R.FP64, R.STRING, R.BOOLEAN)); @@ -19,7 +19,7 @@ class FilterRelRoundtripTest extends TestBase { @Test void simpleEqualityFilter() { // Filter: WHERE id = 100 - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.i32(100)); + Expression condition = sb.equal(sb.fieldReference(baseTable, 0), sb.i32(100)); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -29,7 +29,7 @@ void simpleEqualityFilter() { @Test void stringComparisonFilter() { // Filter: WHERE name = 'John' - Expression condition = b.equal(b.fieldReference(baseTable, 2), b.str("John")); + Expression condition = sb.equal(sb.fieldReference(baseTable, 2), sb.str("John")); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -40,9 +40,9 @@ void stringComparisonFilter() { void andConditionFilter() { // Filter: WHERE id = 10 AND amount = 100.0 Expression condition = - b.and( - b.equal(b.fieldReference(baseTable, 0), b.i32(10)), - b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0))); + sb.and( + sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10)), + sb.equal(sb.fieldReference(baseTable, 1), sb.fp64(100.0))); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -53,9 +53,9 @@ void andConditionFilter() { void orConditionFilter() { // Filter: WHERE id = 5 OR id = 95 Expression condition = - b.or( - b.equal(b.fieldReference(baseTable, 0), b.i32(5)), - b.equal(b.fieldReference(baseTable, 0), b.i32(95))); + sb.or( + sb.equal(sb.fieldReference(baseTable, 0), sb.i32(5)), + sb.equal(sb.fieldReference(baseTable, 0), sb.i32(95))); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -66,12 +66,12 @@ void orConditionFilter() { void complexBooleanFilter() { // Filter: WHERE (id = 10 AND amount = 100) OR status = true Expression andCondition = - b.and( - b.equal(b.fieldReference(baseTable, 0), b.i32(10)), - b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0))); + sb.and( + sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10)), + sb.equal(sb.fieldReference(baseTable, 1), sb.fp64(100.0))); Expression condition = - b.or(andCondition, b.equal(b.fieldReference(baseTable, 3), b.bool(true))); + sb.or(andCondition, sb.equal(sb.fieldReference(baseTable, 3), sb.bool(true))); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -81,7 +81,8 @@ void complexBooleanFilter() { @Test void multipleFieldComparison() { // Filter: WHERE id = amount (comparing two fields) - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 1)); + Expression condition = + sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 1)); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -91,10 +92,10 @@ void multipleFieldComparison() { @Test void nestedFilters() { // Apply filter on top of another filter - Expression firstCondition = b.equal(b.fieldReference(baseTable, 0), b.i32(10)); + Expression firstCondition = sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10)); Rel firstFilter = Filter.builder().input(baseTable).condition(firstCondition).build(); - Expression secondCondition = b.equal(b.fieldReference(firstFilter, 1), b.fp64(100.0)); + Expression secondCondition = sb.equal(sb.fieldReference(firstFilter, 1), sb.fp64(100.0)); Rel secondFilter = Filter.builder().input(firstFilter).condition(secondCondition).build(); verifyRoundTrip(secondFilter); @@ -103,8 +104,8 @@ void nestedFilters() { @Test void filterWithArithmeticExpression() { // Filter: WHERE amount * 2 = 100 - Expression multiply = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); - Expression condition = b.equal(multiply, b.fp64(100.0)); + Expression multiply = sb.multiply(sb.fieldReference(baseTable, 1), sb.fp64(2.0)); + Expression condition = sb.equal(multiply, sb.fp64(100.0)); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -114,7 +115,7 @@ void filterWithArithmeticExpression() { @Test void filterWithBooleanField() { // Filter: WHERE status (direct boolean field) - Expression condition = b.fieldReference(baseTable, 3); + Expression condition = sb.fieldReference(baseTable, 3); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); @@ -124,8 +125,8 @@ void filterWithBooleanField() { @Test void filterWithAddition() { // Filter: WHERE id + id = id (field with itself) - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Expression condition = b.equal(add, b.fieldReference(baseTable, 0)); + Expression add = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)); + Expression condition = sb.equal(add, sb.fieldReference(baseTable, 0)); Rel filter = Filter.builder().input(baseTable).condition(condition).build(); diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java index 268729c68..3e3f42646 100644 --- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java @@ -54,10 +54,7 @@ void roundtripTest(Method m, List paramInst, UnsupportedTypeGenerationEx ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); ProtoExpressionConverter from = new ProtoExpressionConverter( - null, - null, - EMPTY_TYPE, - new ProtoRelConverter(new ExtensionCollector(), defaultExtensionCollection)); + null, null, EMPTY_TYPE, new ProtoRelConverter(new ExtensionCollector(), extensions)); assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index 9b42136fb..d8c12b641 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -19,9 +19,9 @@ class IfThenRoundtripTest extends TestBase { @Test void ifThenNotNullable() { final Expression.IfThen ifRel = - b.ifThen( + sb.ifThen( Arrays.asList( - b.ifClause(ExpressionCreator.bool(false, false), ExpressionCreator.i64(false, 1))), + sb.ifClause(ExpressionCreator.bool(false, false), ExpressionCreator.i64(false, 1))), ExpressionCreator.i64(false, 2)); assertFalse(ifRel.getType().nullable()); @@ -34,9 +34,9 @@ void ifThenNotNullable() { @Test void ifThenNullable() { final Expression.IfThen ifRel = - b.ifThen( + sb.ifThen( Arrays.asList( - b.ifClause(ExpressionCreator.bool(true, false), ExpressionCreator.i64(true, 1))), + sb.ifClause(ExpressionCreator.bool(true, false), ExpressionCreator.i64(true, 1))), ExpressionCreator.i64(false, 2)); assertTrue(ifRel.getType().nullable()); diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 45e8749ba..fd230fb6d 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -12,13 +12,13 @@ class JoinRoundtripTest extends TestBase { final Rel leftTable = - b.namedScan( + sb.namedScan( Arrays.asList("T1"), Arrays.asList("a", "b", "c"), Arrays.asList(R.I64, R.FP64, R.STRING)); final Rel rightTable = - b.namedScan( + sb.namedScan( Arrays.asList("T2"), Arrays.asList("d", "e", "f"), Arrays.asList(R.FP64, R.STRING, R.I64)); @@ -29,7 +29,7 @@ void hashJoin() { List rightKeys = Arrays.asList(2, 0); Rel relWithoutKeys = HashJoin.builder() - .from(b.hashJoin(leftKeys, rightKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) + .from(sb.hashJoin(leftKeys, rightKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) .build(); verifyRoundTrip(relWithoutKeys); } @@ -40,7 +40,8 @@ void mergeJoin() { List rightKeys = Arrays.asList(2, 0); Rel relWithoutKeys = MergeJoin.builder() - .from(b.mergeJoin(leftKeys, rightKeys, MergeJoin.JoinType.INNER, leftTable, rightTable)) + .from( + sb.mergeJoin(leftKeys, rightKeys, MergeJoin.JoinType.INNER, leftTable, rightTable)) .build(); verifyRoundTrip(relWithoutKeys); } @@ -51,8 +52,9 @@ void nestedLoopJoin() { Rel rel = NestedLoopJoin.builder() .from( - b.nestedLoopJoin( - __ -> b.equal(b.fieldReference(inputRels, 0), b.fieldReference(inputRels, 5)), + sb.nestedLoopJoin( + __ -> + sb.equal(sb.fieldReference(inputRels, 0), sb.fieldReference(inputRels, 5)), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) diff --git a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java index 7143ef707..ff23b3fea 100644 --- a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java @@ -33,7 +33,7 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { .build()) .addItems(file); - defaultExtensionCollection.scalarFunctions().stream() + extensions.scalarFunctions().stream() .filter(s -> s.name().equalsIgnoreCase("equal")) .findFirst() .map( diff --git a/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java b/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java index e2c4b18b8..ec3bc885c 100644 --- a/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java +++ b/core/src/test/java/io/substrait/type/proto/NestedListExpressionTest.java @@ -11,25 +11,25 @@ class NestedListExpressionTest extends TestBase { io.substrait.expression.Expression literalExpression = Expression.BoolLiteral.builder().value(true).build(); - Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42)); + Expression.ScalarFunctionInvocation nonLiteralExpression = sb.add(sb.i32(7), sb.i32(42)); @Test void rejectNestedListWithElementsOfDifferentTypes() { ImmutableExpression.NestedList.Builder builder = - Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12)); + Expression.NestedList.builder().addValues(literalExpression).addValues(sb.i32(12)); assertThrows(AssertionError.class, builder::build); } @Test void acceptNestedListWithElementsOfSameType() { ImmutableExpression.NestedList.Builder builder = - Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12)); + Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(sb.i32(12)); assertDoesNotThrow(builder::build); io.substrait.relation.Project project = io.substrait.relation.Project.builder() .addExpressions(builder.build()) - .input(b.emptyScan()) + .input(sb.emptyScan()) .build(); verifyRoundTrip(project); } @@ -51,7 +51,7 @@ void literalNestedListTest() { io.substrait.relation.Project project = io.substrait.relation.Project.builder() .addExpressions(literalNestedList) - .input(b.emptyScan()) + .input(sb.emptyScan()) .build(); verifyRoundTrip(project); @@ -69,7 +69,7 @@ void literalNullableNestedListTest() { io.substrait.relation.Project project = io.substrait.relation.Project.builder() .addExpressions(literalNestedList) - .input(b.emptyScan()) + .input(sb.emptyScan()) .build(); verifyRoundTrip(project); @@ -86,7 +86,7 @@ void nonLiteralNestedListTest() { io.substrait.relation.Project project = io.substrait.relation.Project.builder() .addExpressions(nonLiteralNestedList) - .input(b.emptyScan()) + .input(sb.emptyScan()) .build(); verifyRoundTrip(project); diff --git a/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java index 66a0fd417..9c557913f 100644 --- a/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java @@ -11,7 +11,7 @@ class ProjectRelRoundtripTest extends TestBase { final Rel baseTable = - b.namedScan( + sb.namedScan( Collections.singletonList("test_table"), Arrays.asList("col_a", "col_b", "col_c", "col_d"), Arrays.asList(R.I64, R.FP64, R.STRING, R.I32)); @@ -20,7 +20,7 @@ class ProjectRelRoundtripTest extends TestBase { void simpleProjection() { // Project single field Rel projection = - Project.builder().input(baseTable).addExpressions(b.fieldReference(baseTable, 0)).build(); + Project.builder().input(baseTable).addExpressions(sb.fieldReference(baseTable, 0)).build(); verifyRoundTrip(projection); } @@ -32,9 +32,9 @@ void multipleFieldProjection() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 2), - b.fieldReference(baseTable, 1)) + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 2), + sb.fieldReference(baseTable, 1)) .build(); verifyRoundTrip(projection); @@ -43,7 +43,7 @@ void multipleFieldProjection() { @Test void projectionWithComputedExpression() { // Project with computed expression: col_a + 3 (both I64) - Expression addExpr = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + Expression addExpr = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)); Rel projection = Project.builder().input(baseTable).addExpressions(addExpr).build(); @@ -53,15 +53,15 @@ void projectionWithComputedExpression() { @Test void projectionWithMultipleComputedExpressions() { // Project with multiple computed expressions - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + Expression add = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)); Expression multiply = - b.multiply(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1)); + sb.multiply(sb.fieldReference(baseTable, 1), sb.fieldReference(baseTable, 1)); Rel projection = Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 2), // original field + sb.fieldReference(baseTable, 2), // original field add, // computed col_a + 100 multiply) // computed col_b * 2.0 .build(); @@ -75,7 +75,7 @@ void projectionWithLiterals() { Rel projection = Project.builder() .input(baseTable) - .addExpressions(b.fieldReference(baseTable, 0), b.i32(100), b.str("constant_string")) + .addExpressions(sb.fieldReference(baseTable, 0), sb.i32(100), sb.str("constant_string")) .build(); verifyRoundTrip(projection); @@ -88,10 +88,10 @@ void projectionWithAllFields() { Project.builder() .input(baseTable) .addExpressions( - b.fieldReference(baseTable, 0), - b.fieldReference(baseTable, 1), - b.fieldReference(baseTable, 2), - b.fieldReference(baseTable, 3)) + sb.fieldReference(baseTable, 0), + sb.fieldReference(baseTable, 1), + sb.fieldReference(baseTable, 2), + sb.fieldReference(baseTable, 3)) .build(); verifyRoundTrip(projection); @@ -103,13 +103,13 @@ void nestedProjection() { Rel firstProjection = Project.builder() .input(baseTable) - .addExpressions(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 2)) + .addExpressions(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 2)) .build(); Rel secondProjection = Project.builder() .input(firstProjection) - .addExpressions(b.fieldReference(firstProjection, 1)) + .addExpressions(sb.fieldReference(firstProjection, 1)) .build(); verifyRoundTrip(secondProjection); @@ -118,12 +118,13 @@ void nestedProjection() { @Test void projectionWithComparison() { // Project with comparison expression: col_a = col_d - Expression comparison = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 3)); + Expression comparison = + sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 3)); Rel projection = Project.builder() .input(baseTable) - .addExpressions(b.fieldReference(baseTable, 0), comparison) + .addExpressions(sb.fieldReference(baseTable, 0), comparison) .build(); verifyRoundTrip(projection); @@ -132,7 +133,7 @@ void projectionWithComparison() { @Test void projectionWithCast() { // Project with type cast: CAST(col_d AS BIGINT) - Expression cast = b.cast(b.fieldReference(baseTable, 3), R.I64); + Expression cast = sb.cast(sb.fieldReference(baseTable, 3), R.I64); Rel projection = Project.builder().input(baseTable).addExpressions(cast).build(); diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index 611564fce..6144d784b 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -20,13 +20,13 @@ void namedScan() { List columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); List columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList()); - NamedScan namedScan = b.namedScan(tableName, columnNames, columnTypes); + NamedScan namedScan = sb.namedScan(tableName, columnNames, columnTypes); namedScan = NamedScan.builder() .from(namedScan) .bestEffortFilter( - b.equal(b.fieldReference(namedScan, 0), b.fieldReference(namedScan, 1))) - .filter(b.equal(b.fieldReference(namedScan, 0), b.fieldReference(namedScan, 1))) + sb.equal(sb.fieldReference(namedScan, 0), sb.fieldReference(namedScan, 1))) + .filter(sb.equal(sb.fieldReference(namedScan, 0), sb.fieldReference(namedScan, 1))) .build(); verifyRoundTrip(namedScan); @@ -34,7 +34,7 @@ void namedScan() { @Test void emptyScan() { - io.substrait.relation.EmptyScan emptyScan = b.emptyScan(); + io.substrait.relation.EmptyScan emptyScan = sb.emptyScan(); verifyRoundTrip(emptyScan); } @@ -56,8 +56,8 @@ void virtualTable() { VirtualTableScan.builder() .from(virtTable) .bestEffortFilter( - b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) - .filter(b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) + sb.equal(sb.fieldReference(virtTable, 0), sb.fieldReference(virtTable, 1))) + .filter(sb.equal(sb.fieldReference(virtTable, 0), sb.fieldReference(virtTable, 1))) .build(); verifyRoundTrip(virtTable); } diff --git a/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java index 9ec30822c..039062f24 100644 --- a/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java @@ -11,7 +11,7 @@ class SortRelRoundtripTest extends TestBase { final Rel baseTable = - b.namedScan( + sb.namedScan( Collections.singletonList("test_table"), Arrays.asList("id", "amount", "name", "category", "timestamp"), Arrays.asList(R.I64, R.FP64, R.STRING, R.STRING, R.TIMESTAMP)); @@ -21,7 +21,7 @@ void simpleSortAscending() { // Sort by id ascending, nulls first Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 0)) + .expr(sb.fieldReference(baseTable, 0)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); @@ -35,7 +35,7 @@ void sortAscendingNullsLast() { // Sort by name ascending, nulls last Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 2)) + .expr(sb.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); @@ -49,7 +49,7 @@ void sortDescendingNullsFirst() { // Sort by amount descending, nulls first Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 1)) + .expr(sb.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_FIRST) .build(); @@ -63,7 +63,7 @@ void sortDescendingNullsLast() { // Sort by timestamp descending, nulls last Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 4)) + .expr(sb.fieldReference(baseTable, 4)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); @@ -77,7 +77,7 @@ void sortClustered() { // Sort with clustered direction (no specific order guarantee) Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.CLUSTERED) .build(); @@ -91,13 +91,13 @@ void multipleSortFields() { // Sort by category (asc), then amount (desc) Expression.SortField sortField1 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); Expression.SortField sortField2 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 1)) + .expr(sb.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); @@ -111,19 +111,19 @@ void sortByThreeFields() { // Sort by category, name, and id Expression.SortField sortField1 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); Expression.SortField sortField2 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 2)) + .expr(sb.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); Expression.SortField sortField3 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 0)) + .expr(sb.fieldReference(baseTable, 0)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); @@ -136,7 +136,7 @@ void sortByThreeFields() { @Test void sortByComputedExpression() { // Sort by computed expression: amount * 2 - Expression computedExpr = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); + Expression computedExpr = sb.multiply(sb.fieldReference(baseTable, 1), sb.fp64(2.0)); Expression.SortField sortField = Expression.SortField.builder() @@ -154,7 +154,7 @@ void sortByStringField() { // Sort by string field directly Expression.SortField sortField = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 2)) + .expr(sb.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); @@ -168,19 +168,19 @@ void sortWithMixedNullHandling() { // Sort with different null handling for different fields Expression.SortField sortField1 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); Expression.SortField sortField2 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 1)) + .expr(sb.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_FIRST) .build(); Expression.SortField sortField3 = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 2)) + .expr(sb.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); @@ -198,23 +198,23 @@ void sortAllDirections() { .input(baseTable) .addSortFields( Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 0)) + .expr(sb.fieldReference(baseTable, 0)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(), Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 1)) + .expr(sb.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(), Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 2)) + .expr(sb.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.DESC_NULLS_FIRST) .build(), Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(), Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 4)) + .expr(sb.fieldReference(baseTable, 4)) .direction(Expression.SortDirection.CLUSTERED) .build()) .build(); @@ -227,7 +227,7 @@ void nestedSort() { // Sort on top of another sort Expression.SortField firstSort = Expression.SortField.builder() - .expr(b.fieldReference(baseTable, 3)) + .expr(sb.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); @@ -235,7 +235,7 @@ void nestedSort() { Expression.SortField secondSort = Expression.SortField.builder() - .expr(b.fieldReference(firstSortRel, 0)) + .expr(sb.fieldReference(firstSortRel, 0)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); diff --git a/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java index dd7c1b2fa..1645a5fa0 100644 --- a/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java @@ -47,7 +47,7 @@ void update() { } private Expression.ScalarFunctionInvocation fnAdd(int value) { - return defaultExtensionCollection.scalarFunctions().stream() + return extensions.scalarFunctions().stream() .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() .map( diff --git a/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java index 1ecd1fad4..cda30a7b6 100644 --- a/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java @@ -37,7 +37,7 @@ void insert() { virtTable = VirtualTableScan.builder() .from(virtTable) - .filter(b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) + .filter(sb.equal(sb.fieldReference(virtTable, 0), sb.fieldReference(virtTable, 1))) .build(); NamedWrite command = @@ -56,7 +56,7 @@ void insert() { @Test void append() { ProtoRelConverter protoRelConverter = - new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection); + new StringHolderHandlingProtoRelConverter(functionCollector, extensions); StringHolder detail = new StringHolder("DETAIL"); diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index 4f690e17f..e3082f55b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -1,12 +1,10 @@ package io.substrait.isthmus; import com.google.common.collect.Streams; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -16,11 +14,6 @@ class AggregationFunctionsTest extends PlanTestBase { - SubstraitBuilder b = new SubstraitBuilder(extensions); - - static final TypeCreator R = TypeCreator.of(false); - static final TypeCreator N = TypeCreator.of(true); - // Create a table with that has a column of every numeric type, both NOT NULL and NULL private List numericTypesR = List.of(R.I8, R.I16, R.I32, R.I64, R.FP32, R.FP64); private List numericTypesN = List.of(N.I8, N.I16, N.I32, N.I64, N.FP32, N.FP64); @@ -37,21 +30,21 @@ class AggregationFunctionsTest extends PlanTestBase { private List columnNames = Streams.mapWithIndex(tableTypes.stream(), (t, index) -> String.valueOf(index)) .collect(Collectors.toList()); - private NamedScan numericTypesTable = b.namedScan(List.of("example"), columnNames, tableTypes); + private NamedScan numericTypesTable = sb.namedScan(List.of("example"), columnNames, tableTypes); // Create the given function call on the given field of the input private Aggregate.Measure functionPicker(Rel input, int field, String fname) { switch (fname) { case "min": - return b.min(input, field); + return sb.min(input, field); case "max": - return b.max(input, field); + return sb.max(input, field); case "sum": - return b.sum(input, field); + return sb.sum(input, field); case "sum0": - return b.sum0(input, field); + return sb.sum0(input, field); case "avg": - return b.avg(input, field); + return sb.avg(input, field); default: throw new UnsupportedOperationException( String.format("no function is associated with %s", fname)); @@ -71,8 +64,8 @@ private List functions(Rel input, String fname) { @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) void emptyGrouping(String aggFunction) { Aggregate rel = - b.aggregate( - input -> b.grouping(input), input -> functions(input, aggFunction), numericTypesTable); + sb.aggregate( + input -> sb.grouping(input), input -> functions(input, aggFunction), numericTypesTable); assertFullRoundTrip(rel); } @@ -80,8 +73,8 @@ void emptyGrouping(String aggFunction) { @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) void withGrouping(String aggFunction) { Aggregate rel = - b.aggregate( - input -> b.grouping(input, 0), + sb.aggregate( + input -> sb.grouping(input, 0), input -> functions(input, aggFunction), numericTypesTable); assertFullRoundTrip(rel); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java index 7a0d4ddc5..de53b65b4 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -2,31 +2,22 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ImmutableAggregateFunctionInvocation; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import org.apache.calcite.rel.RelNode; import org.junit.jupiter.api.Test; class ComplexAggregateTest extends PlanTestBase { - protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - DefaultExtensionCatalog.DEFAULT_COLLECTION; - - final TypeCreator R = TypeCreator.of(false); - SubstraitBuilder b = new SubstraitBuilder(extensions); private List columnTypes = List.of(R.I32, R.I32, R.I32, R.I32); private List columnNames = List.of("a", "b", "c", "d"); - private NamedScan table = b.namedScan(List.of("example"), columnNames, columnTypes); + private NamedScan table = sb.namedScan(List.of("example"), columnNames, columnTypes); private Aggregate.Grouping emptyGrouping = Aggregate.Grouping.builder().build(); @@ -59,26 +50,26 @@ protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTrans assertEquals(expectedTransform, converterPojo); // Substrait POJO -> Calcite - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo); + substraitToCalcite.convert(pojo); } @Test void handleComplexMeasureArgument() { // SELECT sum(c + 7) FROM example Aggregate rel = - b.aggregate( + sb.aggregate( input -> emptyGrouping, - input -> List.of(b.sum(b.add(b.fieldReference(input, 2), b.i32(7)))), + input -> List.of(sb.sum(sb.add(sb.fieldReference(input, 2), sb.i32(7)))), table); Aggregate expectedFinal = - b.aggregate( + sb.aggregate( input -> emptyGrouping, // sum call references input field - input -> List.of(b.sum(input, 4)), - b.project( + input -> List.of(sb.sum(input, 4)), + sb.project( // add call is moved to child project - input -> List.of(b.add(b.fieldReference(input, 2), b.i32(7))), + input -> List.of(sb.add(sb.fieldReference(input, 2), sb.i32(7))), table)); validateAggregateTransformation(rel, expectedFinal); @@ -88,19 +79,19 @@ void handleComplexMeasureArgument() { void handleComplexPreMeasureFilter() { // SELECT sum(a) FILTER (b = 42) FROM example Aggregate rel = - b.aggregate( + sb.aggregate( input -> emptyGrouping, input -> List.of( withPreMeasureFilter( - b.sum(input, 0), b.equal(b.fieldReference(input, 1), b.i32(42)))), + sb.sum(input, 0), sb.equal(sb.fieldReference(input, 1), sb.i32(42)))), table); Aggregate expectedFinal = - b.aggregate( + sb.aggregate( input -> emptyGrouping, - input -> List.of(withPreMeasureFilter(b.sum(input, 0), b.fieldReference(input, 4))), - b.project(input -> List.of(b.equal(b.fieldReference(input, 1), b.i32(42))), table)); + input -> List.of(withPreMeasureFilter(sb.sum(input, 0), sb.fieldReference(input, 4))), + sb.project(input -> List.of(sb.equal(sb.fieldReference(input, 1), sb.i32(42))), table)); validateAggregateTransformation(rel, expectedFinal); } @@ -109,32 +100,32 @@ void handleComplexPreMeasureFilter() { void handleComplexSortingArguments() { // SELECT sum(d ORDER BY -b ASC) FROM example Aggregate rel = - b.aggregate( + sb.aggregate( input -> emptyGrouping, input -> List.of( withSort( - b.sum(input, 3), + sb.sum(input, 3), List.of( - b.sortField( - b.negate(b.fieldReference(input, 1)), + sb.sortField( + sb.negate(sb.fieldReference(input, 1)), Expression.SortDirection.ASC_NULLS_FIRST)))), table); Aggregate expectedFinal = - b.aggregate( + sb.aggregate( input -> emptyGrouping, input -> List.of( withSort( - b.sum(input, 3), + sb.sum(input, 3), List.of( - b.sortField( - b.fieldReference(input, 4), + sb.sortField( + sb.fieldReference(input, 4), Expression.SortDirection.ASC_NULLS_FIRST)))), - b.project( + sb.project( // negate call is moved to child project - input -> List.of(b.negate(b.fieldReference(input, 1))), + input -> List.of(sb.negate(sb.fieldReference(input, 1))), table)); validateAggregateTransformation(rel, expectedFinal); @@ -143,22 +134,23 @@ void handleComplexSortingArguments() { @Test void handleComplexGroupingArgument() { Aggregate rel = - b.aggregate( + sb.aggregate( input -> - b.grouping( - b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + sb.grouping( + sb.fieldReference(input, 2), sb.add(sb.fieldReference(input, 1), sb.i32(42))), input -> List.of(), table); Aggregate expectedFinal = - b.aggregate( + sb.aggregate( // grouping exprs are now field references to input - input -> b.grouping(input, 4, 5), + input -> sb.grouping(input, 4, 5), input -> List.of(), - b.project( + sb.project( input -> List.of( - b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + sb.fieldReference(input, 2), + sb.add(sb.fieldReference(input, 1), sb.i32(42))), table)); validateAggregateTransformation(rel, expectedFinal); @@ -166,20 +158,20 @@ void handleComplexGroupingArgument() { @Test void handleOutOfOrderGroupingArguments() { - Aggregate rel = b.aggregate(input -> b.grouping(input, 1, 0, 2), input -> List.of(), table); + Aggregate rel = sb.aggregate(input -> sb.grouping(input, 1, 0, 2), input -> List.of(), table); Aggregate expectedFinal = - b.aggregate( + sb.aggregate( // grouping exprs are now field references to input - input -> b.grouping(input, 4, 5, 6), + input -> sb.grouping(input, 4, 5, 6), input -> List.of(), - b.project( + sb.project( // ALL grouping exprs are added to the child projects (including field references) input -> List.of( - b.fieldReference(input, 1), - b.fieldReference(input, 0), - b.fieldReference(input, 2)), + sb.fieldReference(input, 1), + sb.fieldReference(input, 0), + sb.fieldReference(input, 2)), table)); validateAggregateTransformation(rel, expectedFinal); @@ -188,11 +180,11 @@ void handleOutOfOrderGroupingArguments() { @Test void outOfOrderGroupingKeysHaveCorrectCalciteType() { Rel rel = - b.aggregate( - input -> b.grouping(input, 2, 0), + sb.aggregate( + input -> sb.grouping(input, 2, 0), input -> List.of(), - b.namedScan(List.of("foo"), List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); - RelNode relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); + sb.namedScan(List.of("foo"), List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); + RelNode relNode = substraitToCalcite.convert(rel); assertRowMatch(relNode.getRowType(), R.STRING, R.I64); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index 300786f21..256541980 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -2,12 +2,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; -import io.substrait.type.TypeCreator; import java.io.PrintWriter; import java.io.StringWriter; import java.util.List; @@ -21,15 +17,6 @@ class ComplexSortTest extends PlanTestBase { - private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - DefaultExtensionCatalog.DEFAULT_COLLECTION; - - final TypeCreator R = TypeCreator.of(false); - SubstraitBuilder b = new SubstraitBuilder(extensions); - - final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - /** * A {@link RelWriterImpl} that annotates each {@link RelNode} with its {@link RelCollation} trait * information. A {@link RelNode} is only annotated if its {@link RelCollation} is not empty. @@ -58,15 +45,15 @@ void handleInputReferenceSort() { // SELECT a FROM example ORDER BY a Rel rel = - b.project( - input -> b.fieldReferences(input, 0), - b.remap(1), - b.sort( + sb.project( + input -> sb.fieldReferences(input, 0), + sb.remap(1), + sb.sort( input -> List.of( - b.sortField( - b.fieldReference(input, 0), Expression.SortDirection.ASC_NULLS_LAST)), - b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); + sb.sortField( + sb.fieldReference(input, 0), Expression.SortDirection.ASC_NULLS_LAST)), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = "Collation: [0]\n" @@ -85,16 +72,16 @@ void handleCastExpressionSort() { // SELECT a FROM example ORDER BY a::INT Rel rel = - b.project( - input -> b.fieldReferences(input, 0), - b.remap(1), - b.sort( + sb.project( + input -> sb.fieldReferences(input, 0), + sb.remap(1), + sb.sort( input -> List.of( - b.sortField( - b.cast(b.fieldReference(input, 0), R.I32), + sb.sortField( + sb.cast(sb.fieldReference(input, 0), R.I32), Expression.SortDirection.ASC_NULLS_LAST)), - b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); + sb.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = "LogicalProject(a0=[$0])\n" @@ -115,16 +102,16 @@ void handleCastProjectAndSortWithSortDirection() { // SELECT a::INT FROM example ORDER BY a::INT DESC NULLS LAST Rel rel = - b.project( - input -> List.of(b.cast(b.fieldReference(input, 0), R.I32)), - b.remap(1), - b.sort( + sb.project( + input -> List.of(sb.cast(sb.fieldReference(input, 0), R.I32)), + sb.remap(1), + sb.sort( input -> List.of( - b.sortField( - b.cast(b.fieldReference(input, 0), R.I32), + sb.sortField( + sb.cast(sb.fieldReference(input, 0), R.I32), Expression.SortDirection.DESC_NULLS_LAST)), - b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); + sb.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = "LogicalProject(a0=[CAST($0):INTEGER NOT NULL])\n" @@ -145,16 +132,16 @@ void handleCastSortToOriginalType() { // SELECT a FROM example ORDER BY a::VARCHAR Rel rel = - b.project( - input -> List.of(b.fieldReference(input, 0)), - b.remap(1), - b.sort( + sb.project( + input -> List.of(sb.fieldReference(input, 0)), + sb.remap(1), + sb.sort( input -> List.of( - b.sortField( - b.cast(b.fieldReference(input, 0), R.STRING), + sb.sortField( + sb.cast(sb.fieldReference(input, 0), R.STRING), Expression.SortDirection.DESC_NULLS_LAST)), - b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); + sb.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); String expected = "LogicalProject(a0=[$0])\n" @@ -175,19 +162,19 @@ void handleComplex2ExpressionSort() { // SELECT b, a FROM example ORDER BY a::INT DESC, -b + 42 ASC NULLS LAST Rel rel = - b.project( - input -> List.of(b.fieldReference(input, 0), b.fieldReference(input, 1)), - b.remap(2, 3), - b.sort( + sb.project( + input -> List.of(sb.fieldReference(input, 0), sb.fieldReference(input, 1)), + sb.remap(2, 3), + sb.sort( input -> List.of( - b.sortField( - b.cast(b.fieldReference(input, 0), R.I32), + sb.sortField( + sb.cast(sb.fieldReference(input, 0), R.I32), Expression.SortDirection.DESC_NULLS_FIRST), - b.sortField( - b.add(b.negate(b.fieldReference(input, 1)), b.i32(42)), + sb.sortField( + sb.add(sb.negate(sb.fieldReference(input, 1)), sb.i32(42)), Expression.SortDirection.ASC_NULLS_LAST)), - b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32)))); + sb.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32)))); String expected = "LogicalProject(a0=[$0], b0=[$1])\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 34e06d0ac..5da4b728f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import com.google.protobuf.Any; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; import io.substrait.extension.ExtensionCollector; @@ -58,8 +57,6 @@ class CustomFunctionTest extends PlanTestBase { static final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); - final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); - // Create user-defined types static final String aTypeName = "a_type"; static final String bTypeName = "b_type"; @@ -265,6 +262,10 @@ public RelDataType toCalcite(Type.UserDefined type) { typeConverter, ImmutableFeatureBoard.builder().build()); + CustomFunctionTest() { + super(extensionCollection); + } + // Create a SubstraitToCalcite converter that has access to the custom Function Converters class CustomSubstraitToCalcite extends SubstraitToCalcite { @@ -292,11 +293,12 @@ void customScalarFunctionRoundtrip() { // CREATE TABLE example(a TEXT) // SELECT custom_scalar(a) FROM example Rel rel = - b.project( + sb.project( input -> - List.of(b.scalarFn(URN, "custom_scalar:str", R.STRING, b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))); + List.of( + sb.scalarFn(URN, "custom_scalar:str", R.STRING, sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -306,12 +308,13 @@ void customScalarFunctionRoundtrip() { @Test void customScalarAnyFunctionRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn(URN, "custom_scalar_any:any", R.STRING, b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); + sb.scalarFn( + URN, "custom_scalar_any:any", R.STRING, sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -321,13 +324,13 @@ void customScalarAnyFunctionRoundtrip() { @Test void customScalarAnyToAnyFunctionRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( - URN, "custom_scalar_any_to_any:any", R.FP64, b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.FP64))); + sb.scalarFn( + URN, "custom_scalar_any_to_any:any", R.FP64, sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.FP64))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -337,17 +340,17 @@ void customScalarAnyToAnyFunctionRoundtrip() { @Test void customScalarAny1Any1ToAny1FunctionRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_any1any1_to_any1:any_any", R.FP64, - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.FP64))); + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.FP64))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -357,17 +360,17 @@ void customScalarAny1Any1ToAny1FunctionRoundtrip() { @Test void customScalarAny1Any1ToAny1FunctionMismatch() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_any1any1_to_any1:any_any", R.FP64, - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); assertThrows( IllegalArgumentException.class, @@ -381,17 +384,17 @@ void customScalarAny1Any1ToAny1FunctionMismatch() { @Test void customScalarAny1Any2ToAny2FunctionRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_any1any2_to_any2:any_any", R.STRING, - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -401,16 +404,16 @@ void customScalarAny1Any2ToAny2FunctionRoundtrip() { @Test void customScalarListAnyRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_listany_to_listany:list", R.list(R.I64), - b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.I64)))); + sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.I64)))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -420,17 +423,17 @@ void customScalarListAnyRoundtrip() { @Test void customScalarListAnyAndAnyRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_listany_any_to_listany:list_any", R.list(R.STRING), - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan( + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); RelNode calciteRel = substraitToCalcite.convert(rel); @@ -441,16 +444,16 @@ void customScalarListAnyAndAnyRoundtrip() { @Test void customScalarListStringRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_liststring_to_liststring:list", R.list(R.STRING), - b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); + sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -460,17 +463,17 @@ void customScalarListStringRoundtrip() { @Test void customScalarListStringAndAnyRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_liststring_any_to_liststring:list_any", R.list(R.STRING), - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan( + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); RelNode calciteRel = substraitToCalcite.convert(rel); @@ -481,19 +484,19 @@ void customScalarListStringAndAnyRoundtrip() { @Test void customScalarListStringAndAnyVariadic0Roundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_liststring_anyvariadic0_to_liststring:list_any", R.list(R.STRING), - b.fieldReference(input, 0), - b.fieldReference(input, 1), - b.fieldReference(input, 2), - b.fieldReference(input, 3))), - b.remap(4), - b.namedScan( + sb.fieldReference(input, 0), + sb.fieldReference(input, 1), + sb.fieldReference(input, 2), + sb.fieldReference(input, 3))), + sb.remap(4), + sb.namedScan( List.of("example"), List.of("a", "b", "c", "d"), List.of(R.list(R.STRING), R.STRING, R.STRING, R.STRING))); @@ -506,16 +509,16 @@ void customScalarListStringAndAnyVariadic0Roundtrip() { @Test void customScalarListStringAndAnyVariadic0NoArgsRoundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_liststring_anyvariadic0_to_liststring:list_any", R.list(R.STRING), - b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); + sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -525,17 +528,17 @@ void customScalarListStringAndAnyVariadic0NoArgsRoundtrip() { @Test void customScalarListStringAndAnyVariadic1Roundtrip() { Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "custom_scalar_liststring_anyvariadic1_to_liststring:list_any", R.list(R.STRING), - b.fieldReference(input, 0), - b.fieldReference(input, 1))), - b.remap(2), - b.namedScan( + sb.fieldReference(input, 0), + sb.fieldReference(input, 1))), + sb.remap(2), + sb.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); RelNode calciteRel = substraitToCalcite.convert(rel); @@ -548,14 +551,14 @@ void customAggregateFunctionRoundtrip() { // CREATE TABLE example (a BIGINT) // SELECT custom_aggregate(a) FROM example GROUP BY a Rel rel = - b.aggregate( - input -> b.grouping(input, 0), + sb.aggregate( + input -> sb.grouping(input, 0), input -> List.of( - b.measure( - b.aggregateFn( - URN, "custom_aggregate:i64", R.I64, b.fieldReference(input, 0)))), - b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); + sb.measure( + sb.aggregateFn( + URN, "custom_aggregate:i64", R.I64, sb.fieldReference(input, 0)))), + sb.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -567,16 +570,16 @@ void customTypesInFunctionsRoundtrip() { // CREATE TABLE example(a a_type) // SELECT to_b_type(a) FROM example Rel rel = - b.project( + sb.project( input -> List.of( - b.scalarFn( + sb.scalarFn( URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), - b.fieldReference(input, 0))), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); RelNode calciteRel = substraitToCalcite.convert(rel); Rel relReturned = calciteToSubstrait.apply(calciteRel); @@ -590,11 +593,11 @@ void customTypesLiteralInFunctionsRoundtrip() { UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); Rel rel1 = - b.project( + sb.project( input -> - List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), - b.remap(1), - b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); + List.of(sb.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), + sb.remap(1), + sb.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); RelNode calciteRel = substraitToCalcite.convert(rel1); Rel rel2 = calciteToSubstrait.apply(calciteRel); diff --git a/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java index 20237e0f4..72fe318e8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java @@ -1,29 +1,24 @@ package io.substrait.isthmus; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.EmptyListLiteral; import io.substrait.expression.ExpressionCreator; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; class EmptyArrayLiteralTest extends PlanTestBase { - private static final TypeCreator N = TypeCreator.of(true); - - private final SubstraitBuilder b = new SubstraitBuilder(extensions); @Test void emptyArrayLiteral() { Type colType = N.I8; EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(false, N.I8); Project rel = - b.project( + sb.project( input -> List.of(emptyListLiteral), Rel.Remap.offset(1, 1), - b.namedScan(List.of("t"), List.of("col"), List.of(colType))); + sb.namedScan(List.of("t"), List.of("col"), List.of(colType))); assertFullRoundTrip(rel); } @@ -32,10 +27,10 @@ void nullableEmptyArrayLiteral() { Type colType = N.I8; EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(true, N.I8); Project rel = - b.project( + sb.project( input -> List.of(emptyListLiteral), Rel.Remap.offset(1, 1), - b.namedScan(List.of("t"), List.of("col"), List.of(colType))); + sb.namedScan(List.of("t"), List.of("col"), List.of(colType))); assertFullRoundTrip(rel); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index db04ffc28..00316b5be 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -6,7 +6,6 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; @@ -18,7 +17,6 @@ import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.Rel; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import io.substrait.util.EmptyVisitationContext; import java.io.IOException; import java.util.List; @@ -31,11 +29,6 @@ /** Tests which test that an expression can be converted to and from Calcite expressions. */ class ExpressionConvertabilityTest extends PlanTestBase { - static final TypeCreator R = TypeCreator.of(false); - static final TypeCreator N = TypeCreator.of(true); - - final SubstraitBuilder b = new SubstraitBuilder(extensions); - final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(new ExtensionCollector(), null); @@ -51,7 +44,7 @@ class ExpressionConvertabilityTest extends PlanTestBase { // Define a shared table (i.e. a NamedScan) for use in tests. final List commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN); final Rel commonTable = - b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); + sb.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); @Test void listLiteral() throws IOException, SqlParseException { @@ -72,7 +65,8 @@ void inPredicate() throws IOException, SqlParseException { @Test void singleOrList() { - Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10)); + Expression singleOrList = + sb.singleOrList(sb.fieldReference(commonTable, 0), sb.i32(5), sb.i32(10)); RexNode rexNode = singleOrList.accept(converter, Context.newContext()); Expression substraitExpression = rexNode.accept( @@ -82,19 +76,19 @@ void singleOrList() { // cannot roundtrip test singleOrList because Calcite simplifies the representation assertExpressionEquality( - b.or( - b.equal(b.fieldReference(commonTable, 0), b.i32(5)), - b.equal(b.fieldReference(commonTable, 0), b.i32(10))), + sb.or( + sb.equal(sb.fieldReference(commonTable, 0), sb.i32(5)), + sb.equal(sb.fieldReference(commonTable, 0), sb.i32(10))), substraitExpression); } @Test void switchExpression() { Expression switchExpression = - b.switchExpression( - b.fieldReference(commonTable, 0), - List.of(b.switchClause(b.i32(5), b.i32(1)), b.switchClause(b.i32(10), b.i32(2))), - b.i32(3)); + sb.switchExpression( + sb.fieldReference(commonTable, 0), + List.of(sb.switchClause(sb.i32(5), sb.i32(1)), sb.switchClause(sb.i32(10), sb.i32(2))), + sb.i32(3)); RexNode rexNode = switchExpression.accept(converter, Context.newContext()); Expression expression = rexNode.accept( @@ -103,28 +97,30 @@ void switchExpression() { // cannot roundtrip test switchExpression because Calcite simplifies the representation assertExpressionEquality( - b.ifThen( + sb.ifThen( List.of( - b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(5)), b.i32(1)), - b.ifClause(b.equal(b.fieldReference(commonTable, 0), b.i32(10)), b.i32(2))), - b.i32(3)), + sb.ifClause(sb.equal(sb.fieldReference(commonTable, 0), sb.i32(5)), sb.i32(1)), + sb.ifClause(sb.equal(sb.fieldReference(commonTable, 0), sb.i32(10)), sb.i32(2))), + sb.i32(3)), expression); } @Test void castFailureCondition() { Rel rel = - b.project( + sb.project( input -> List.of( ExpressionCreator.cast( R.I64, - b.fieldReference(input, 0), + sb.fieldReference(input, 0), Expression.FailureBehavior.THROW_EXCEPTION), ExpressionCreator.cast( - R.I32, b.fieldReference(input, 0), Expression.FailureBehavior.RETURN_NULL)), - b.remap(1, 2), - b.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING))); + R.I32, + sb.fieldReference(input, 0), + Expression.FailureBehavior.RETURN_NULL)), + sb.remap(1, 2), + sb.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING))); assertFullRoundTrip(rel); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java index 8e9824490..e2eb554e8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java @@ -1,34 +1,28 @@ package io.substrait.isthmus; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.relation.Rel; -import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; class FetchTest extends PlanTestBase { - static final TypeCreator R = TypeCreator.of(false); - - final SubstraitBuilder b = new SubstraitBuilder(extensions); - - final Rel TABLE = b.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING)); + final Rel TABLE = sb.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING)); @Test void limitOnly() { - Rel rel = b.limit(50, TABLE); + Rel rel = sb.limit(50, TABLE); assertFullRoundTrip(rel); } @Test void offsetOnly() { - Rel rel = b.offset(50, TABLE); + Rel rel = sb.offset(50, TABLE); assertFullRoundTrip(rel); } @Test void offsetAndLimit() { - Rel rel = b.fetch(50, 10, TABLE); + Rel rel = sb.fetch(50, 10, TABLE); assertFullRoundTrip(rel); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java index 211f1b1ef..e08728400 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; import io.substrait.expression.Expression.ScalarFunctionInvocation; @@ -27,8 +26,6 @@ */ class FunctionConversionTest extends PlanTestBase { - final SubstraitBuilder b = new SubstraitBuilder(extensions); - final ExpressionRexConverter expressionRexConverter = new ExpressionRexConverter( typeFactory, @@ -52,7 +49,7 @@ void subtractDateIDay() { // this is being mapped to the wrong Calcite function. // TODO: https://github.com/substrait-io/substrait-java/issues/377 Expression.ScalarFunctionInvocation expr = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "subtract:date_iday", TypeCreator.REQUIRED.DATE, @@ -71,7 +68,7 @@ void subtractDateIDay() { @Test void extractTimestampTzScalarFunction() { ScalarFunctionInvocation reqTstzFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_tstz_str", TypeCreator.REQUIRED.I64, @@ -92,7 +89,7 @@ void extractTimestampTzScalarFunction() { @Test void extractPrecisionTimestampTzScalarFunction() { ScalarFunctionInvocation reqPtstzFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ptstz_str", TypeCreator.REQUIRED.I64, @@ -113,7 +110,7 @@ void extractPrecisionTimestampTzScalarFunction() { @Test void extractTimestampScalarFunction() { ScalarFunctionInvocation reqTsFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", TypeCreator.REQUIRED.I64, @@ -131,7 +128,7 @@ void extractTimestampScalarFunction() { @Test void extractPrecisionTimestampScalarFunction() { ScalarFunctionInvocation reqPtsFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_pts", TypeCreator.REQUIRED.I64, @@ -149,7 +146,7 @@ void extractPrecisionTimestampScalarFunction() { @Test void extractDateScalarFunction() { ScalarFunctionInvocation reqDateFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_date", TypeCreator.REQUIRED.I64, @@ -167,7 +164,7 @@ void extractDateScalarFunction() { @Test void extractTimeScalarFunction() { ScalarFunctionInvocation reqTimeFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_time", TypeCreator.REQUIRED.I64, @@ -185,7 +182,7 @@ void extractTimeScalarFunction() { @Test void extractDateWithIndexing() { ScalarFunctionInvocation reqReqDateFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", TypeCreator.REQUIRED.I64, @@ -204,7 +201,7 @@ void extractDateWithIndexing() { @Test void unsupportedExtractTimestampTzWithIndexing() { ScalarFunctionInvocation reqReqTstzFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_tstz_str", TypeCreator.REQUIRED.I64, @@ -221,7 +218,7 @@ void unsupportedExtractTimestampTzWithIndexing() { @Test void unsupportedExtractPrecisionTimestampTzWithIndexing() { ScalarFunctionInvocation reqReqPtstzFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_ptstz_str", TypeCreator.REQUIRED.I64, @@ -238,7 +235,7 @@ void unsupportedExtractPrecisionTimestampTzWithIndexing() { @Test void unsupportedExtractTimestampWithIndexing() { ScalarFunctionInvocation reqReqTsFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_ts", TypeCreator.REQUIRED.I64, @@ -254,7 +251,7 @@ void unsupportedExtractTimestampWithIndexing() { @Test void unsupportedExtractPrecisionTimestampWithIndexing() { ScalarFunctionInvocation reqReqPtsFn = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_pts", TypeCreator.REQUIRED.I64, diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index a99c5ff35..983e9c84e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -23,9 +23,6 @@ void preserveNamesFromSql() throws Exception { CalciteCatalogReader catalogReader = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatement); - SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); @@ -45,7 +42,7 @@ void preserveNamesFromSql() throws Exception { @Test void preserveNamesFromSubstrait() { NamedScan rel = - substraitBuilder.namedScan( + sb.namedScan( List.of("foo"), List.of("i64", "struct", "struct0", "struct1"), List.of(R.I64, R.struct(R.FP64, R.STRING))); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java index e0c4b8023..f6b6324af 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java @@ -3,11 +3,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import com.google.protobuf.ByteString; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; import io.substrait.expression.ImmutableExpression; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.type.Type; @@ -19,22 +16,17 @@ class NestedExpressionsTest extends PlanTestBase { - protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection = - DefaultExtensionCatalog.DEFAULT_COLLECTION; - protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - Expression literalExpression = Expression.BoolLiteral.builder().value(true).build(); - Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42)); - Expression.ScalarFunctionInvocation nonLiteralExpression2 = b.add(b.i32(3), b.i32(4)); + Expression.ScalarFunctionInvocation nonLiteralExpression = sb.add(sb.i32(7), sb.i32(42)); + Expression.ScalarFunctionInvocation nonLiteralExpression2 = sb.add(sb.i32(3), sb.i32(4)); final List tableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN, N.STRING); final Rel commonTable = - b.namedScan(List.of("example"), List.of("a", "b", "c", "d", "e"), tableType); - final Rel emptyTable = b.emptyScan(); + sb.namedScan(List.of("example"), List.of("a", "b", "c", "d", "e"), tableType); + final Rel emptyTable = sb.emptyScan(); - Expression fieldRef1 = b.fieldReference(commonTable, 2); - Expression fieldRef2 = b.fieldReference(commonTable, 4); + Expression fieldRef1 = sb.fieldReference(commonTable, 2); + Expression fieldRef2 = sb.fieldReference(commonTable, 4); @Test void nestedListWithLiteralsTest() { @@ -99,7 +91,7 @@ void nestedListWithFieldReferenceTest() { @Test void nestedListWithStringLiteralsTest() { Expression.NestedList nestedList = - Expression.NestedList.builder().addValues(b.str("xzy")).addValues(b.str("abc")).build(); + Expression.NestedList.builder().addValues(sb.str("xzy")).addValues(sb.str("abc")).build(); Rel project = Project.builder().expressions(List.of(nestedList)).input(emptyTable).build(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index a37916bc4..2397c3e53 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -39,14 +39,18 @@ import org.apache.calcite.tools.RelBuilder; public class PlanTestBase { + + protected static final TypeCreator R = TypeCreator.of(false); + protected static final TypeCreator N = TypeCreator.of(true); + protected final SimpleExtension.ExtensionCollection extensions; protected final RelCreator creator = new RelCreator(); protected final RelBuilder builder = creator.createRelBuilder(); protected final RelDataTypeFactory typeFactory = creator.typeFactory(); - protected final SubstraitBuilder substraitBuilder; - protected static final TypeCreator R = TypeCreator.of(false); - protected static final TypeCreator N = TypeCreator.of(true); + + protected final SubstraitBuilder sb; + protected final SubstraitToCalcite substraitToCalcite; protected static final CalciteCatalogReader TPCH_CATALOG; @@ -70,7 +74,8 @@ protected PlanTestBase() { protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) { this.extensions = extensions; - this.substraitBuilder = new SubstraitBuilder(extensions); + this.sb = new SubstraitBuilder(extensions); + this.substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); } public static String asString(String resource) throws IOException { diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java index d6d7868b7..b1f1730ce 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProjectTest.java @@ -1,18 +1,16 @@ package io.substrait.isthmus; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.relation.Project; import io.substrait.relation.Rel; import org.junit.jupiter.api.Test; class ProjectTest extends PlanTestBase { - final SubstraitBuilder b = new SubstraitBuilder(extensions); - final Rel emptyTable = b.emptyScan(); + final Rel emptyTable = sb.emptyScan(); @Test void avoidProjectRemapOnEmptyInput() { Rel projection = - Project.builder().input(emptyTable).addExpressions(b.add(b.i32(1), b.i32(2))).build(); + Project.builder().input(emptyTable).addExpressions(sb.add(sb.i32(1), sb.i32(2))).build(); assertFullRoundTrip(projection); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index fc1c4d812..19831745e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -44,25 +44,23 @@ class RelExtensionRoundtripTest extends PlanTestBase { @Test void extensionLeafRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(1)); + ColumnAppendDetail detail = new ColumnAppendDetail(sb.i32(1)); ImmutableExtensionLeaf rel = ExtensionLeaf.from(detail).build(); roundtrip(rel); } @Test void extensionSingleRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(2)); - ImmutableExtensionSingle rel = - ExtensionSingle.from(detail, substraitBuilder.emptyScan()).build(); + ColumnAppendDetail detail = new ColumnAppendDetail(sb.i32(2)); + ImmutableExtensionSingle rel = ExtensionSingle.from(detail, sb.emptyScan()).build(); roundtrip(rel); } @Test void extensionMultiRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(3)); + ColumnAppendDetail detail = new ColumnAppendDetail(sb.i32(3)); ImmutableExtensionMulti rel = - ExtensionMulti.from(detail, substraitBuilder.emptyScan(), substraitBuilder.emptyScan()) - .build(); + ExtensionMulti.from(detail, sb.emptyScan(), sb.emptyScan()).build(); roundtrip(rel); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java index 2869ad797..c14d6737e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; @@ -21,7 +20,6 @@ class SchemaCollectorTest extends PlanTestBase { - SubstraitBuilder b = substraitBuilder; SchemaCollector schemaCollector = new SchemaCollector(typeFactory, TypeConverter.DEFAULT); void hasTable(CalciteSchema schema, String tableName, String tableSchema) { @@ -33,12 +31,12 @@ void hasTable(CalciteSchema schema, String tableName, String tableSchema) { @Test void canCollectTables() { Rel rel = - b.cross( - b.namedScan( + sb.cross( + sb.namedScan( List.of("table1"), List.of("col1", "col2", "col3"), List.of(N.I64, R.FP64, N.STRING)), - b.namedScan(List.of("table2"), List.of("col4", "col5"), List.of(N.BOOLEAN, N.I32))); + sb.namedScan(List.of("table2"), List.of("col4", "col5"), List.of(N.BOOLEAN, N.I32))); CalciteSchema calciteSchema = schemaCollector.toSchema(rel); hasTable( @@ -51,23 +49,23 @@ void canCollectTables() { @Test void canCollectTablesInSchemas() { Rel rel = - b.namedWrite( + sb.namedWrite( List.of("schema3", "table4"), List.of("col1", "col2", "col3", "col4", "col5", "col6"), AbstractWriteRel.WriteOp.UPDATE, AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS, AbstractWriteRel.OutputMode.MODIFIED_RECORDS, - b.cross( - b.cross( - b.namedScan( + sb.cross( + sb.cross( + sb.namedScan( List.of("schema1", "table1"), List.of("col1", "col2", "col3"), List.of(N.I64, N.FP64, N.STRING)), - b.namedScan( + sb.namedScan( List.of("schema1", "table2"), List.of("col4", "col5"), List.of(N.BOOLEAN, N.I32))), - b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)))); + sb.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)))); CalciteSchema calciteSchema = schemaCollector.toSchema(rel); CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); @@ -113,13 +111,13 @@ void testUpdate() { Expression condition = ExpressionCreator.bool(false, true); Rel rel = - b.namedWrite( + sb.namedWrite( List.of("schema1", "table2"), List.of("col1"), AbstractWriteRel.WriteOp.INSERT, AbstractWriteRel.CreateMode.APPEND_IF_EXISTS, AbstractWriteRel.OutputMode.NO_OUTPUT, - b.namedUpdate( + sb.namedUpdate( List.of("schema1", "table1"), List.of("col1"), transformations, condition, true)); CalciteSchema calciteSchema = schemaCollector.toSchema(rel); @@ -132,10 +130,10 @@ void testUpdate() { @Test void canHandleMultipleSchemas() { Rel rel = - b.cross( - b.namedScan( + sb.cross( + sb.namedScan( List.of("level1", "level2a", "level3", "t1"), List.of("col1"), List.of(N.I64)), - b.namedScan(List.of("level1", "level2b", "t2"), List.of("col2"), List.of(N.I32))); + sb.namedScan(List.of("level1", "level2b", "t2"), List.of("col2"), List.of(N.I32))); CalciteSchema rootSchema = schemaCollector.toSchema(rel); CalciteSchema level1 = rootSchema.getSubSchema("level1", false); @@ -150,8 +148,8 @@ void canHandleMultipleSchemas() { @Test void canHandleDuplicateNamedScans() { - Rel table = b.namedScan(List.of("table"), List.of("col1"), List.of(N.BOOLEAN)); - Rel rel = b.cross(table, table); + Rel table = sb.namedScan(List.of("table"), List.of("col1"), List.of(N.BOOLEAN)); + Rel rel = sb.cross(table, table); CalciteSchema calciteSchema = schemaCollector.toSchema(rel); hasTable(calciteSchema, "table", "RecordType(BOOLEAN col1) NOT NULL"); @@ -160,9 +158,9 @@ void canHandleDuplicateNamedScans() { @Test void validatesSchemasForDuplicateNamedScans() { Rel rel = - b.cross( - b.namedScan(List.of("t"), List.of("col1"), List.of(N.BOOLEAN)), - b.namedScan(List.of("t"), List.of("col1"), List.of(R.BOOLEAN))); + sb.cross( + sb.namedScan(List.of("t"), List.of("col1"), List.of(N.BOOLEAN)), + sb.namedScan(List.of("t"), List.of("col1"), List.of(R.BOOLEAN))); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> schemaCollector.toSchema(rel)); @@ -174,9 +172,9 @@ void validatesSchemasForDuplicateNamedScans() { @Test void validatesSchemasForNestedDuplicateNamedScans() { Rel rel = - b.cross( - b.namedScan(List.of("s", "t"), List.of("col1"), List.of(N.BOOLEAN)), - b.namedScan(List.of("s", "t"), List.of("col1"), List.of(R.BOOLEAN))); + sb.cross( + sb.namedScan(List.of("s", "t"), List.of("col1"), List.of(N.BOOLEAN)), + sb.namedScan(List.of("s", "t"), List.of("col1"), List.of(R.BOOLEAN))); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> schemaCollector.toSchema(rel)); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java index d2ffd60ef..f9d8d763d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; import io.substrait.expression.Expression.Switch; import io.substrait.expression.WindowBound; @@ -17,7 +16,6 @@ import io.substrait.relation.Rel; import io.substrait.relation.Rel.Remap; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.logical.LogicalProject; @@ -28,16 +26,11 @@ class SubstraitExpressionConverterTest extends PlanTestBase { - static final TypeCreator R = TypeCreator.of(false); - static final TypeCreator N = TypeCreator.of(true); - - final SubstraitBuilder b = new SubstraitBuilder(extensions); - final ExpressionRexConverter converter; final List commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN); final Rel commonTable = - b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); + sb.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); final SubstraitRelNodeConverter relNodeConverter = new SubstraitRelNodeConverter(extensions, typeFactory, builder); @@ -56,10 +49,10 @@ public SubstraitExpressionConverterTest() { @Test void switchExpression() { Switch expr = - b.switchExpression( - b.fieldReference(commonTable, 0), - List.of(b.switchClause(b.i32(0), b.fieldReference(commonTable, 3))), - b.bool(false)); + sb.switchExpression( + sb.fieldReference(commonTable, 0), + List.of(sb.switchClause(sb.i32(0), sb.fieldReference(commonTable, 3))), + sb.bool(false)); RexNode calciteExpr = expr.accept(converter, Context.newContext()); assertTypeMatch(calciteExpr.getType(), N.BOOLEAN); @@ -72,9 +65,8 @@ void scalarSubQuery() { Expression.ScalarSubquery expr = Expression.ScalarSubquery.builder().type(R.I64).input(subQueryRel).build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + Project query = sb.project(input -> List.of(expr), sb.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -93,9 +85,8 @@ void existsSetPredicate() { .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + Project query = sb.project(input -> List.of(expr), sb.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -114,9 +105,8 @@ void uniqueSetPredicate() { .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + Project query = sb.project(input -> List.of(expr), sb.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -135,9 +125,8 @@ void unspecifiedSetPredicate() { .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + Project query = sb.project(input -> List.of(expr), sb.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); Exception exception = assertThrows( UnsupportedOperationException.class, @@ -158,23 +147,23 @@ void unspecifiedSetPredicate() { * @return the Substrait {@link Rel} equivalent of the above SQL query */ Rel createSubQueryRel() { - return b.project( - input -> List.of(b.fieldReference(input, 0)), + return sb.project( + input -> List.of(sb.fieldReference(input, 0)), Remap.of(List.of(3)), - b.filter(input -> b.equal(b.fieldReference(input, 2), b.str("EUROPE")), commonTable)); + sb.filter(input -> sb.equal(sb.fieldReference(input, 2), sb.str("EUROPE")), commonTable)); } @Test void useSubstraitReturnTypeDuringScalarFunctionConversion() { Expression.ScalarFunctionInvocation expr = - b.scalarFn( + sb.scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "add:i32_i32", // THIS IS (INTENTIONALLY) THE WRONG OUTPUT TYPE // SHOULD BE R.I32 R.FP32, - b.i32(7), - b.i32(42)); + sb.i32(7), + sb.i32(42)); RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.FP32), calciteExpr.getType()); @@ -183,7 +172,7 @@ void useSubstraitReturnTypeDuringScalarFunctionConversion() { @Test void useSubstraitReturnTypeDuringWindowFunctionConversion() { Expression.WindowFunctionInvocation expr = - b.windowFn( + sb.windowFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "row_number:", // THIS IS (INTENTIONALLY) THE WRONG OUTPUT TYPE @@ -194,7 +183,7 @@ void useSubstraitReturnTypeDuringWindowFunctionConversion() { Expression.WindowBoundsType.RANGE, WindowBound.UNBOUNDED, WindowBound.UNBOUNDED, - b.i32(42)); + sb.i32(42)); RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.STRING), calciteExpr.getType()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java index e9cc9e02a..d751c9694 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -1,13 +1,11 @@ package io.substrait.isthmus; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.plan.Plan; import io.substrait.relation.Join.JoinType; import io.substrait.relation.Rel; import io.substrait.relation.Set.SetOp; import io.substrait.type.NamedStruct; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -18,47 +16,40 @@ class SubstraitRelNodeConverterTest extends PlanTestBase { - static final TypeCreator R = TypeCreator.of(false); - static final TypeCreator N = TypeCreator.of(true); - - final SubstraitBuilder b = new SubstraitBuilder(extensions); - // Define a shared table (i.e. a NamedScan) for use in tests. final List commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN); final List commonTableTypeTwice = Stream.concat(commonTableType.stream(), commonTableType.stream()) .collect(Collectors.toList()); final Rel commonTable = - b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); - - final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + sb.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); @Nested class Aggregate { @Test void direct() { Plan.Root root = - b.root( - b.aggregate( - input -> b.grouping(input, 0, 2), - input -> List.of(b.count(input, 0)), + sb.root( + sb.aggregate( + input -> sb.grouping(input, 0, 2), + input -> List.of(sb.count(input, 0)), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.I64); } @Test void emit() { Plan.Root root = - b.root( - b.aggregate( - input -> b.grouping(input, 0, 2), - input -> List.of(b.count(input, 0)), - b.remap(1, 2), + sb.root( + sb.aggregate( + input -> sb.grouping(input, 0, 2), + input -> List.of(sb.count(input, 0)), + sb.remap(1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, R.I64); } } @@ -67,17 +58,17 @@ void emit() { class Cross { @Test void direct() { - Plan.Root root = b.root(b.cross(commonTable, commonTable)); + Plan.Root root = sb.root(sb.cross(commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @Test void emit() { - Plan.Root root = b.root(b.cross(commonTable, commonTable, b.remap(0, 1, 4, 6))); + Plan.Root root = sb.root(sb.cross(commonTable, commonTable, sb.remap(0, 1, 4, 6))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32, R.I32, N.STRING); } } @@ -86,17 +77,17 @@ void emit() { class Fetch { @Test void direct() { - Plan.Root root = b.root(b.fetch(20, 40, commonTable)); + Plan.Root root = sb.root(sb.fetch(20, 40, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.fetch(20, 40, b.remap(0, 2), commonTable)); + Plan.Root root = sb.root(sb.fetch(20, 40, sb.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -105,17 +96,17 @@ void emit() { class Filter { @Test void direct() { - Plan.Root root = b.root(b.filter(input -> b.bool(true), commonTable)); + Plan.Root root = sb.root(sb.filter(input -> sb.bool(true), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.filter(input -> b.bool(true), b.remap(0, 2), commonTable)); + Plan.Root root = sb.root(sb.filter(input -> sb.bool(true), sb.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -124,66 +115,66 @@ void emit() { class Join { @Test void direct() { - Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), commonTable, commonTable)); + Plan.Root root = sb.root(sb.innerJoin(input -> sb.bool(true), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @Test void emit() { Plan.Root root = - b.root(b.innerJoin(input -> b.bool(true), b.remap(0, 6), commonTable, commonTable)); + sb.root(sb.innerJoin(input -> sb.bool(true), sb.remap(0, 6), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } @Test void leftJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); - final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + final Rel joinTable = sb.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); Plan.Root root = - b.root( - b.project( - r -> b.fieldReferences(r, 0, 1, 3), - b.remap(6, 7, 8), - b.join(ji -> b.bool(true), JoinType.LEFT, joinTable, joinTable))); + sb.root( + sb.project( + r -> sb.fieldReferences(r, 0, 1, 3), + sb.remap(6, 7, 8), + sb.join(ji -> sb.bool(true), JoinType.LEFT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.STRING, R.FP64, N.STRING); } @Test void rightJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); - final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + final Rel joinTable = sb.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); Plan.Root root = - b.root( - b.project( - r -> b.fieldReferences(r, 0, 1, 3), - b.remap(6, 7, 8), - b.join(ji -> b.bool(true), JoinType.RIGHT, joinTable, joinTable))); + sb.root( + sb.project( + r -> sb.fieldReferences(r, 0, 1, 3), + sb.remap(6, 7, 8), + sb.join(ji -> sb.bool(true), JoinType.RIGHT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, R.STRING); } @Test void outerJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); - final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); + final Rel joinTable = sb.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); Plan.Root root = - b.root( - b.project( - r -> b.fieldReferences(r, 0, 1, 3), - b.remap(6, 7, 8), - b.join(ji -> b.bool(true), JoinType.OUTER, joinTable, joinTable))); + sb.root( + sb.project( + r -> sb.fieldReferences(r, 0, 1, 3), + sb.remap(6, 7, 8), + sb.join(ji -> sb.bool(true), JoinType.OUTER, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING); } } @@ -193,20 +184,20 @@ class NamedScan { @Test void direct() { Plan.Root root = - b.root(b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32))); + sb.root(sb.namedScan(List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32); } @Test void emit() { Plan.Root root = - b.root( - b.namedScan( - List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32), b.remap(1))); + sb.root( + sb.namedScan( + List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32), sb.remap(1))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.FP32); } } @@ -215,9 +206,10 @@ void emit() { class Project { @Test void direct() { - Plan.Root root = b.root(b.project(input -> b.fieldReferences(input, 1, 0, 2), commonTable)); + Plan.Root root = + sb.root(sb.project(input -> sb.fieldReferences(input, 1, 0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch( relNode.getRowType(), R.I32, R.FP32, N.STRING, N.BOOLEAN, R.FP32, R.I32, N.STRING); } @@ -225,11 +217,11 @@ void direct() { @Test void emit() { Plan.Root root = - b.root( - b.project( - input -> b.fieldReferences(input, 1, 0, 2), b.remap(0, 2, 4, 6), commonTable)); + sb.root( + sb.project( + input -> sb.fieldReferences(input, 1, 0, 2), sb.remap(0, 2, 4, 6), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.FP32, N.STRING); } } @@ -238,17 +230,17 @@ void emit() { class Set { @Test void direct() { - Plan.Root root = b.root(b.set(SetOp.UNION_ALL, commonTable, commonTable)); + Plan.Root root = sb.root(sb.set(SetOp.UNION_ALL, commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.set(SetOp.UNION_ALL, b.remap(0, 2), commonTable, commonTable)); + Plan.Root root = sb.root(sb.set(SetOp.UNION_ALL, sb.remap(0, 2), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -257,18 +249,18 @@ void emit() { class Sort { @Test void direct() { - Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), commonTable)); + Plan.Root root = sb.root(sb.sort(input -> sb.sortFields(input, 0, 1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { Plan.Root root = - b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), b.remap(0, 2), commonTable)); + sb.root(sb.sort(input -> sb.sortFields(input, 0, 1, 2), sb.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -283,8 +275,8 @@ void direct() { .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct(R.I32, N.STRING))) .build(); - Plan.Root root = b.root(emptyScan); - RelNode relNode = converter.convert(root.getInput()); + Plan.Root root = sb.root(emptyScan); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), List.of(R.I32, N.STRING)); } @@ -296,8 +288,8 @@ void emit() { .remap(Rel.Remap.of(List.of(0))) .build(); - Plan.Root root = b.root(emptyScanWithRemap); - RelNode relNode = converter.convert(root.getInput()); + Plan.Root root = sb.root(emptyScanWithRemap); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java index d5d8ada75..68997a828 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java @@ -13,18 +13,17 @@ import org.junit.jupiter.api.Test; class SubstraitToCalciteTest extends PlanTestBase { - final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); @Test void testConvertRootSingleColumn() { Iterable types = List.of(TypeCreator.REQUIRED.STRING); Root root = Root.builder() - .input(substraitBuilder.namedScan(List.of("stores"), List.of("s"), types)) + .input(sb.namedScan(List.of("stores"), List.of("s"), types)) .addNames("store") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @@ -34,11 +33,11 @@ void testConvertRootMultipleColumns() { Iterable types = List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); Root root = Root.builder() - .input(substraitBuilder.namedScan(List.of("stores"), List.of("s_store_id", "s"), types)) + .input(sb.namedScan(List.of("stores"), List.of("s_store_id", "s"), types)) .addNames("s_store_id", "store") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @@ -51,14 +50,13 @@ void testConvertRootStructField() { Root root = Root.builder() .input( - substraitBuilder.namedScan( - List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) + sb.namedScan(List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) .addNames("store", "store_id", "store_name") .build(); assertEquals(List.of("store", "store_id", "store_name"), root.getNames()); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -79,12 +77,11 @@ void testConvertRootArrayWithStructField() { Root root = Root.builder() .input( - substraitBuilder.namedScan( - List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) + sb.namedScan(List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -109,12 +106,11 @@ void testConvertRootMapWithStructValues() { Root root = Root.builder() .input( - substraitBuilder.namedScan( - List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) + sb.namedScan(List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) .addNames("store", "store_id", "store_name") .build(); - final RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -139,12 +135,11 @@ void testConvertRootMapWithStructKeys() { Root root = Root.builder() .input( - substraitBuilder.namedScan( - List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) + sb.namedScan(List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types)) .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java index 7d0b26fd8..ce22c7d41 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java @@ -2,7 +2,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; @@ -19,9 +18,6 @@ class VirtualTableScanTest extends PlanTestBase { - final SubstraitBuilder b = new SubstraitBuilder(extensions); - final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - @Test void literalOnlyVirtualTable() { NamedStruct schema = @@ -29,8 +25,8 @@ void literalOnlyVirtualTable() { VirtualTableScan virtualTableScan = createVirtualTableScan( schema, - List.of(b.i32(2), b.fp64(4), b.str("a")), - List.of(b.i32(6), b.fp64(8.8), b.str("b"))); + List.of(sb.i32(2), sb.fp64(4), sb.str("a")), + List.of(sb.i32(6), sb.fp64(8.8), sb.str("b"))); // Check the specific Calcite encoding RelNode relNode = substraitToCalcite.convert(virtualTableScan); @@ -48,8 +44,8 @@ void expressionContainingVirtualTable() { VirtualTableScan virtualTableScan = createVirtualTableScan( schema, - List.of(b.i32(2), b.add(b.fp64(4.4), b.fp64(4.5))), - List.of(b.multiply(b.i32(6), b.i32(2)), b.fp64(8.8))); + List.of(sb.i32(2), sb.add(sb.fp64(4.4), sb.fp64(4.5))), + List.of(sb.multiply(sb.i32(6), sb.i32(2)), sb.fp64(8.8))); // Check the specific Calcite encoding RelNode relNode = substraitToCalcite.convert(virtualTableScan); diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java index acf6942da..64ae6d60c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -190,10 +190,10 @@ void rejectQueriesWithIgnoreNulls() { @ValueSource(strings = {"lag", "lead"}) void lagLeadFunctions(String function) { Rel rel = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.windowFn( + sb.windowFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("%s:any", function), R.FP64, @@ -202,9 +202,9 @@ void lagLeadFunctions(String function) { Expression.WindowBoundsType.ROWS, WindowBound.Preceding.UNBOUNDED, WindowBound.Following.CURRENT_ROW, - substraitBuilder.fieldReference(input, 0))), - substraitBuilder.remap(1), - substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + sb.fieldReference(input, 0))), + sb.remap(1), + sb.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); assertFullRoundTrip(rel); } @@ -213,10 +213,10 @@ void lagLeadFunctions(String function) { @ValueSource(strings = {"lag", "lead"}) void lagLeadWithOffset(String function) { Rel rel = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.windowFn( + sb.windowFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("%s:any_i32", function), R.FP64, @@ -225,10 +225,10 @@ void lagLeadWithOffset(String function) { Expression.WindowBoundsType.RANGE, WindowBound.Preceding.UNBOUNDED, WindowBound.Following.UNBOUNDED, - substraitBuilder.fieldReference(input, 0), - substraitBuilder.i32(1))), - substraitBuilder.remap(1), - substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + sb.fieldReference(input, 0), + sb.i32(1))), + sb.remap(1), + sb.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); assertFullRoundTrip(rel); } @@ -237,10 +237,10 @@ void lagLeadWithOffset(String function) { @ValueSource(strings = {"lag", "lead"}) void lagLeadWithOffsetAndDefault(String function) { Rel rel = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.windowFn( + sb.windowFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("%s:any_i32_any", function), R.I64, @@ -249,11 +249,11 @@ void lagLeadWithOffsetAndDefault(String function) { Expression.WindowBoundsType.ROWS, WindowBound.Preceding.UNBOUNDED, WindowBound.Following.CURRENT_ROW, - substraitBuilder.fieldReference(input, 0), - substraitBuilder.i32(1), - substraitBuilder.fp64(100.0))), - substraitBuilder.remap(1), - substraitBuilder.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); + sb.fieldReference(input, 0), + sb.i32(1), + sb.fp64(100.0))), + sb.remap(1), + sb.namedScan(List.of("window_test"), List.of("a"), List.of(R.FP64))); assertFullRoundTrip(rel); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java index ea82e0a8d..410bdccd6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -4,7 +4,6 @@ import io.substrait.expression.FieldReference; import io.substrait.isthmus.PlanTestBase; -import io.substrait.isthmus.SubstraitToCalcite; import io.substrait.isthmus.sql.SubstraitSqlDialect; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -15,22 +14,21 @@ import org.junit.jupiter.api.Test; class SubqueryConversionTest extends PlanTestBase { - protected final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); private final Rel customerTableScan = - substraitBuilder.namedScan( + sb.namedScan( List.of("customer"), List.of("c_custkey", "c_nationkey"), List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); private final NamedScan orderTableScan = - substraitBuilder.namedScan( + sb.namedScan( List.of("orders"), List.of("o_orderkey", "o_custkey"), List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); private final NamedScan nationTableScan = - substraitBuilder.namedScan( + sb.namedScan( List.of("nation"), List.of("n_nationkey", "n_name"), List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)); @@ -44,22 +42,22 @@ void testOuterFieldReferenceOneStep() { * FROM orders */ final Rel root = - substraitBuilder.project( + sb.project( input -> List.of( // orders.o_orderkey - substraitBuilder.fieldReference(input, 0), + sb.fieldReference(input, 0), // (SELECT customer.c_nationkey FROM customer WHERE customer.c_custkey = // orders.o_custkey) - substraitBuilder.scalarSubquery( - substraitBuilder.project( - input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), Remap.of(List.of(1)), - substraitBuilder.filter( + sb.filter( input2 -> - substraitBuilder.equal( + sb.equal( // customer.c_custkey - substraitBuilder.fieldReference(input2, 0), + sb.fieldReference(input2, 0), // orders.o_custkey FieldReference.newRootStructOuterReference( 1, TypeCreator.REQUIRED.I64, 1)), @@ -68,7 +66,7 @@ void testOuterFieldReferenceOneStep() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // LogicalFilter has field reference with $cor0 correlation variable // outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -110,30 +108,27 @@ void testOuterFieldReferenceTwoSteps() { * FROM orders */ final Rel root = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.fieldReference(input, 0), - substraitBuilder.scalarSubquery( - substraitBuilder.project( - input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), Remap.of(List.of(2)), - substraitBuilder.filter( + sb.filter( input2 -> - substraitBuilder.equal( - substraitBuilder.fieldReference(input2, 0), - substraitBuilder.scalarSubquery( - substraitBuilder.project( - input3 -> - List.of( - substraitBuilder.fieldReference(input3, 1)), + sb.equal( + sb.fieldReference(input2, 0), + sb.scalarSubquery( + sb.project( + input3 -> List.of(sb.fieldReference(input3, 1)), Remap.of(List.of(1)), - substraitBuilder.filter( + sb.filter( input3 -> - substraitBuilder.equal( + sb.equal( // customer.c_custkey - substraitBuilder.fieldReference( - input3, 0), + sb.fieldReference(input3, 0), // orders.o_custkey FieldReference .newRootStructOuterReference( @@ -147,7 +142,7 @@ void testOuterFieldReferenceTwoSteps() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field reference with $cor0 correlation variable // most outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -195,37 +190,36 @@ void testInPredicateOuterFieldReference() { * FROM orders */ final Rel root = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.fieldReference(input, 0), - substraitBuilder.scalarSubquery( - substraitBuilder.project( - input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), Remap.of(List.of(2)), - substraitBuilder.filter( + sb.filter( input2 -> - substraitBuilder.inPredicate( - substraitBuilder.project( - input3 -> - List.of(substraitBuilder.fieldReference(input3, 1)), + sb.inPredicate( + sb.project( + input3 -> List.of(sb.fieldReference(input3, 1)), Remap.of(List.of(1)), - substraitBuilder.filter( + sb.filter( input3 -> - substraitBuilder.equal( + sb.equal( // customer.c_custkey - substraitBuilder.fieldReference(input3, 0), + sb.fieldReference(input3, 0), // orders.o_custkey FieldReference.newRootStructOuterReference( 1, TypeCreator.REQUIRED.I64, 2)), customerTableScan)), - substraitBuilder.fieldReference(input2, 0)), + sb.fieldReference(input2, 0)), nationTableScan)), TypeCreator.NULLABLE.STRING)), Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field reference with $cor0 correlation variable // most outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -274,38 +268,35 @@ void testSetPredicateOuterFieldReference() { * FROM orders */ final Rel root = - substraitBuilder.project( + sb.project( input -> List.of( - substraitBuilder.fieldReference(input, 0), - substraitBuilder.scalarSubquery( - substraitBuilder.project( - input2 -> List.of(substraitBuilder.fieldReference(input2, 1)), + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), Remap.of(List.of(2)), - substraitBuilder.filter( + sb.filter( input2 -> - substraitBuilder.exists( - substraitBuilder.project( - input3 -> - List.of(substraitBuilder.fieldReference(input3, 1)), + sb.exists( + sb.project( + input3 -> List.of(sb.fieldReference(input3, 1)), Remap.of(List.of(1)), - substraitBuilder.filter( + sb.filter( input3 -> - substraitBuilder.and( - substraitBuilder.equal( + sb.and( + sb.equal( // customer.c_custkey - substraitBuilder.fieldReference( - input3, 0), + sb.fieldReference(input3, 0), // orders.o_custkey FieldReference .newRootStructOuterReference( 1, TypeCreator.REQUIRED.I64, 2)), - substraitBuilder.equal( + sb.equal( // customer.c_nationkey - substraitBuilder.fieldReference( - input3, 1), + sb.fieldReference(input3, 1), // nation.n_nationkey FieldReference .newRootStructOuterReference( @@ -318,7 +309,7 @@ void testSetPredicateOuterFieldReference() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field references with $cor0 and $cor1 correlation variables // most outer LogicalProject has variablesSet containing $cor0 correlation variable