diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java
index 3defbf78f..8307785a1 100644
--- a/core/src/test/java/io/substrait/TestBase.java
+++ b/core/src/test/java/io/substrait/TestBase.java
@@ -2,6 +2,8 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
+import com.google.common.base.Charsets;
+import com.google.common.io.Resources;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
@@ -10,24 +12,38 @@
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.TypeCreator;
+import java.io.IOException;
public abstract class TestBase {
- protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection =
- DefaultExtensionCatalog.DEFAULT_COLLECTION;
+ protected static final TypeCreator R = TypeCreator.REQUIRED;
+ protected static final TypeCreator N = TypeCreator.NULLABLE;
- protected TypeCreator R = TypeCreator.REQUIRED;
- protected TypeCreator N = TypeCreator.NULLABLE;
+ protected final SimpleExtension.ExtensionCollection extensions;
- protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
protected ExtensionCollector functionCollector = new ExtensionCollector();
protected RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
- protected ProtoRelConverter protoRelConverter =
- new ProtoRelConverter(functionCollector, defaultExtensionCollection);
+
+ protected SubstraitBuilder sb;
+ protected ProtoRelConverter protoRelConverter;
+
+ protected TestBase() {
+ this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
+ }
+
+ protected TestBase(SimpleExtension.ExtensionCollection extensions) {
+ this.extensions = extensions;
+ this.sb = new SubstraitBuilder(extensions);
+ this.protoRelConverter = new ProtoRelConverter(functionCollector, extensions);
+ }
protected void verifyRoundTrip(Rel rel) {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
Rel relReturned = protoRelConverter.from(protoRel);
assertEquals(rel, relReturned);
}
+
+ public static String asString(String resource) throws IOException {
+ return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
+ }
}
diff --git a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java
index 46fc5041e..a25a9009c 100644
--- a/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java
+++ b/core/src/test/java/io/substrait/expression/VariadicParameterConsistencyTest.java
@@ -3,19 +3,17 @@
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
+import io.substrait.TestBase;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
-import io.substrait.type.TypeCreator;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;
/** Tests for variadic parameter consistency validation in Expression. */
-class VariadicParameterConsistencyTest {
-
- private static final TypeCreator R = TypeCreator.of(false);
+class VariadicParameterConsistencyTest extends TestBase {
/**
* Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The
diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java
index b84ef8bd2..9b6ae95a8 100644
--- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java
+++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java
@@ -88,7 +88,7 @@ private static ImmutableExpressionReference getFieldReferenceExpression() {
private static ImmutableExpressionReference getScalarFunctionExpression() {
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
- new SubstraitBuilder(defaultExtensionCollection)
+ new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION)
.scalarFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC_DECIMAL,
"add:dec_dec",
@@ -111,8 +111,9 @@ private static ImmutableAggregateFunctionReference getAggregateFunctionReference
.function(
AggregateFunctionInvocation.builder()
.arguments(Collections.emptyList())
- .declaration(defaultExtensionCollection.aggregateFunctions().get(0))
- .outputType(TypeCreator.of(false).I64)
+ .declaration(
+ DefaultExtensionCatalog.DEFAULT_COLLECTION.aggregateFunctions().get(0))
+ .outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build())
diff --git a/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java
index c217d078e..dec329ae0 100644
--- a/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java
+++ b/core/src/test/java/io/substrait/extension/AdvancedExtensionRelProtoConversionTest.java
@@ -2,7 +2,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
-import io.substrait.dsl.SubstraitBuilder;
+import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.SortDirection;
import io.substrait.expression.ExpressionCreator;
@@ -48,9 +48,7 @@
import io.substrait.utils.StringHolderHandlingProtoExtensionConverter;
import org.junit.jupiter.api.Test;
-class AdvancedExtensionRelProtoConversionTest {
- final SubstraitBuilder builder = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION);
-
+class AdvancedExtensionRelProtoConversionTest extends TestBase {
final StringHolder enhanced = new StringHolder("ENHANCED");
final StringHolder optimized = new StringHolder("OPTIMIZED");
final AdvancedExtension, ?> extension =
@@ -168,7 +166,7 @@ void testAggregateRelConversionRoundtrip() throws Exception {
.initialSchema(
NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build())
.build())
- .addMeasures(builder.countStar())
+ .addMeasures(sb.countStar())
.extension(extension)
.build();
@@ -191,7 +189,7 @@ void testSortRelConversionRoundtrip() throws Exception {
.addSortFields(
SortField.builder()
.direction(SortDirection.ASC_NULLS_FIRST)
- .expr(builder.fieldReference(scan, 0))
+ .expr(sb.fieldReference(scan, 0))
.build())
.extension(extension)
.build();
@@ -214,8 +212,7 @@ void testJoinRelConversionRoundtrip() throws Exception {
.left(scan)
.right(scan)
.joinType(JoinType.INNER)
- .condition(
- builder.equal(builder.fieldReference(scan, 0), builder.fieldReference(scan, 0)))
+ .condition(sb.equal(sb.fieldReference(scan, 0), sb.fieldReference(scan, 0)))
.extension(extension)
.build();
@@ -235,7 +232,7 @@ void testProjectRelConversionRoundtrip() throws Exception {
final Project rel =
Project.builder()
.input(scan)
- .addExpressions(builder.fieldReference(scan, 0))
+ .addExpressions(sb.fieldReference(scan, 0))
.extension(extension)
.build();
@@ -384,12 +381,12 @@ void testNamedUpdateRelConversionRoundtrip() throws Exception {
.addNames("CUSTOMER")
.tableSchema(schema)
.condition(
- builder.equal(
+ sb.equal(
FieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.BOOLEAN)
.build(),
- builder.bool(true)))
+ sb.bool(true)))
.addTransformations(
TransformExpression.builder()
.columnTarget(0)
diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java
index abb4008d5..4b3b94a80 100644
--- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java
+++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java
@@ -3,13 +3,13 @@
import static io.substrait.type.TypeCreator.REQUIRED;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import io.substrait.dsl.SubstraitBuilder;
+import io.substrait.TestBase;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.type.Type;
-import io.substrait.type.TypeCreator;
-import java.io.InputStream;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -24,24 +24,29 @@
*
Roundtrip between POJO and Proto
*
*/
-class TypeExtensionTest {
-
- static final TypeCreator R = TypeCreator.of(false);
+class TypeExtensionTest extends TestBase {
static final String URN = "extension:test:custom_extensions";
- final SimpleExtension.ExtensionCollection extensionCollection;
+ static final SimpleExtension.ExtensionCollection CUSTOM_EXTENSION;
- {
- String path = "/extensions/custom_extensions.yaml";
- InputStream inputStream = this.getClass().getResourceAsStream(path);
- extensionCollection = SimpleExtension.load(path, inputStream);
+ static {
+ try {
+ String customExtensionStr = asString("extensions/custom_extensions.yaml");
+ CUSTOM_EXTENSION = SimpleExtension.load(URN, customExtensionStr);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
}
- final SubstraitBuilder b = new SubstraitBuilder(extensionCollection);
- Type customType1 = b.userDefinedType(URN, "customType1");
- Type customType2 = b.userDefinedType(URN, "customType2");
+ Type customType1 = sb.userDefinedType(URN, "customType1");
+ Type customType2 = sb.userDefinedType(URN, "customType2");
final PlanProtoConverter planProtoConverter = new PlanProtoConverter();
- final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection);
+ final ProtoPlanConverter protoPlanConverter;
+
+ TypeExtensionTest() {
+ super(CUSTOM_EXTENSION);
+ this.protoPlanConverter = new ProtoPlanConverter(extensions);
+ }
@Test
void roundtripCustomType() {
@@ -49,26 +54,26 @@ void roundtripCustomType() {
List tableName = Stream.of("example").collect(Collectors.toList());
List columnNames =
Stream.of("custom_type_column", "i64_column").collect(Collectors.toList());
- List types = Stream.of(customType1, R.I64).collect(Collectors.toList());
+ List types = Stream.of(customType1, R.I64).collect(Collectors.toList());
// SELECT custom_type_column, scalar1(custom_type_column), scalar2(i64_column)
// FROM example
Plan plan =
- b.plan(
- b.root(
- b.project(
+ sb.plan(
+ sb.root(
+ sb.project(
input ->
Stream.of(
- b.fieldReference(input, 0),
- b.scalarFn(
+ sb.fieldReference(input, 0),
+ sb.scalarFn(
URN,
"scalar1:u!customType1",
R.I64,
- b.fieldReference(input, 0)),
- b.scalarFn(
- URN, "scalar2:i64", customType2, b.fieldReference(input, 1)))
+ sb.fieldReference(input, 0)),
+ sb.scalarFn(
+ URN, "scalar2:i64", customType2, sb.fieldReference(input, 1)))
.collect(Collectors.toList()),
- b.namedScan(tableName, columnNames, types))));
+ sb.namedScan(tableName, columnNames, types))));
io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
@@ -80,19 +85,21 @@ void roundtripNumberedAnyTypes() {
List tableName = Stream.of("example").collect(Collectors.toList());
List columnNames =
Stream.of("array_i64_type_column", "array_i64_column").collect(Collectors.toList());
- List types =
- Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList());
+ List types = Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList());
Plan plan =
- b.plan(
- b.root(
- b.project(
+ sb.plan(
+ sb.root(
+ sb.project(
input ->
Stream.of(
- b.scalarFn(
- URN, "array_index:list_i64", R.I64, b.fieldReference(input, 0)))
+ sb.scalarFn(
+ URN,
+ "array_index:list_i64",
+ R.I64,
+ sb.fieldReference(input, 0)))
.collect(Collectors.toList()),
- b.namedScan(tableName, columnNames, types))));
+ sb.namedScan(tableName, columnNames, types))));
io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java
index 3876fd2ad..d84f06234 100644
--- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java
+++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java
@@ -25,7 +25,7 @@
class ProtoRelConverterTest extends TestBase {
final NamedScan commonTable =
- b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
+ sb.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
/**
* Verify default behaviour of {@link ProtoRelConverter} in the presence of {@link
@@ -121,9 +121,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter())
final Rel relFromProto =
new ProtoRelConverter(
- functionCollector,
- defaultExtensionCollection,
- new StringHolderHandlingProtoExtensionConverter())
+ functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter())
.from(protoRel);
assertEquals(rel, relFromProto);
@@ -140,9 +138,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter())
final Rel relFromProto =
new ProtoRelConverter(
- functionCollector,
- defaultExtensionCollection,
- new StringHolderHandlingProtoExtensionConverter())
+ functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter())
.from(protoRel);
assertEquals(rel, relFromProto);
diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java
index 01276fd4a..6d32e9301 100644
--- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java
@@ -13,7 +13,6 @@
import io.substrait.relation.RelProtoConverter;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
-import io.substrait.type.TypeCreator;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
@@ -33,8 +32,6 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio
ExtensionCollector functionCollector = new ExtensionCollector();
RelProtoConverter to = new RelProtoConverter(functionCollector);
- io.substrait.extension.SimpleExtension.ExtensionCollection extensions =
- defaultExtensionCollection;
ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions);
io.substrait.relation.ImmutableMeasure measure =
@@ -43,7 +40,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio
AggregateFunctionInvocation.builder()
.arguments(Collections.emptyList())
.declaration(extensions.aggregateFunctions().get(0))
- .outputType(TypeCreator.of(false).I64)
+ .outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(invocation)
.options(
@@ -56,7 +53,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio
Arrays.asList(
Expression.SortField.builder()
// SORT BY decimal
- .expr(b.fieldReference(input, 0))
+ .expr(sb.fieldReference(input, 0))
.direction(Expression.SortDirection.ASC_NULLS_LAST)
.build()))
.build())
diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java
index afd8a494d..bf8639380 100644
--- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java
@@ -18,11 +18,11 @@ class ConsistentPartitionWindowRelRoundtripTest extends TestBase {
@Test
void consistentPartitionWindowRoundtripSingle() {
SimpleExtension.WindowFunctionVariant windowFunctionDeclaration =
- defaultExtensionCollection.getWindowFunction(
+ extensions.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
Rel input =
- b.namedScan(
+ sb.namedScan(
Arrays.asList("test"),
Arrays.asList("a", "b", "c"),
Arrays.asList(R.I64, R.I16, R.I32));
@@ -34,7 +34,7 @@ void consistentPartitionWindowRoundtripSingle() {
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(windowFunctionDeclaration)
// lead(a)
- .arguments(Arrays.asList(b.fieldReference(input, 0)))
+ .arguments(Arrays.asList(sb.fieldReference(input, 0)))
.options(
Arrays.asList(
FunctionOption.builder()
@@ -49,12 +49,12 @@ void consistentPartitionWindowRoundtripSingle() {
.boundsType(Expression.WindowBoundsType.RANGE)
.build()))
// PARTITION BY b
- .partitionExpressions(Arrays.asList(b.fieldReference(input, 1)))
+ .partitionExpressions(Arrays.asList(sb.fieldReference(input, 1)))
.sorts(
Arrays.asList(
Expression.SortField.builder()
// SORT BY c
- .expr(b.fieldReference(input, 2))
+ .expr(sb.fieldReference(input, 2))
.direction(Expression.SortDirection.ASC_NULLS_FIRST)
.build()))
.build();
@@ -71,15 +71,15 @@ void consistentPartitionWindowRoundtripSingle() {
@Test
void consistentPartitionWindowRoundtripMulti() {
SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration =
- defaultExtensionCollection.getWindowFunction(
+ extensions.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration =
- defaultExtensionCollection.getWindowFunction(
+ extensions.getWindowFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any"));
Rel input =
- b.namedScan(
+ sb.namedScan(
Arrays.asList("test"),
Arrays.asList("a", "b", "c"),
Arrays.asList(R.I64, R.I16, R.I32));
@@ -91,7 +91,7 @@ void consistentPartitionWindowRoundtripMulti() {
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(windowFunctionLeadDeclaration)
// lead(a)
- .arguments(Arrays.asList(b.fieldReference(input, 0)))
+ .arguments(Arrays.asList(sb.fieldReference(input, 0)))
.options(
Arrays.asList(
FunctionOption.builder()
@@ -108,7 +108,7 @@ void consistentPartitionWindowRoundtripMulti() {
ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(windowFunctionLagDeclaration)
// lag(a)
- .arguments(Arrays.asList(b.fieldReference(input, 0)))
+ .arguments(Arrays.asList(sb.fieldReference(input, 0)))
.options(
Arrays.asList(
FunctionOption.builder()
@@ -123,12 +123,12 @@ void consistentPartitionWindowRoundtripMulti() {
.boundsType(Expression.WindowBoundsType.RANGE)
.build()))
// PARTITION BY b
- .partitionExpressions(Arrays.asList(b.fieldReference(input, 1)))
+ .partitionExpressions(Arrays.asList(sb.fieldReference(input, 1)))
.sorts(
Arrays.asList(
Expression.SortField.builder()
// SORT BY c
- .expr(b.fieldReference(input, 2))
+ .expr(sb.fieldReference(input, 2))
.direction(Expression.SortDirection.ASC_NULLS_FIRST)
.build()))
.build();
diff --git a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java
index 82a816db0..5eb0c3100 100644
--- a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java
@@ -44,7 +44,7 @@ void create() {
@Test
void alter() {
ProtoRelConverter protoRelConverter =
- new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection);
+ new StringHolderHandlingProtoRelConverter(functionCollector, extensions);
StringHolder detail = new StringHolder("DETAIL");
diff --git a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java
index 95b06b8fc..36e757289 100644
--- a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java
@@ -17,7 +17,7 @@
class ExchangeRelRoundtripTest extends TestBase {
final Rel baseTable =
- b.namedScan(
+ sb.namedScan(
Collections.singletonList("exchange_test_table"),
Arrays.asList("id", "amount", "name", "status"),
Arrays.asList(R.I64, R.FP64, R.STRING, R.BOOLEAN));
@@ -42,7 +42,7 @@ void scatterExchange() {
Rel exchange =
ScatterExchange.builder()
.input(baseTable)
- .addFields(b.fieldReference(baseTable, 0))
+ .addFields(sb.fieldReference(baseTable, 0))
.partitionCount(1)
.build();
@@ -55,7 +55,7 @@ void singleBucketExchange() {
SingleBucketExchange.builder()
.input(baseTable)
.partitionCount(1)
- .expression(b.fieldReference(baseTable, 0))
+ .expression(sb.fieldReference(baseTable, 0))
.build();
verifyRoundTrip(exchange);
@@ -66,7 +66,7 @@ void multiBucketExchange() {
Rel exchange =
MultiBucketExchange.builder()
.input(baseTable)
- .expression(b.fieldReference(baseTable, 0))
+ .expression(sb.fieldReference(baseTable, 0))
.constrainedToCount(true)
.partitionCount(1)
.build();
diff --git a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java
index 127ac6c00..ef5e60a41 100644
--- a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java
@@ -13,20 +13,20 @@
class ExpandRelRoundtripTest extends TestBase {
final Rel input =
- b.namedScan(
+ sb.namedScan(
Stream.of("a_table").collect(Collectors.toList()),
Stream.of("column1", "column2").collect(Collectors.toList()),
Stream.of(R.I64, R.I64).collect(Collectors.toList()));
private Expand.ExpandField getConsistentField(int index) {
- return Expand.ConsistentField.builder().expression(b.fieldReference(input, index)).build();
+ return Expand.ConsistentField.builder().expression(sb.fieldReference(input, index)).build();
}
private Expand.ExpandField getSwitchingField(List indexes) {
return Expand.SwitchingField.builder()
.addAllDuplicates(
indexes.stream()
- .map(index -> b.fieldReference(input, index))
+ .map(index -> sb.fieldReference(input, index))
.collect(Collectors.toList()))
.build();
}
@@ -35,7 +35,7 @@ private Expand.ExpandField getSwitchingField(List indexes) {
void expandConsistent() {
Rel rel =
Expand.builder()
- .from(b.expand(__ -> Collections.emptyList(), input))
+ .from(sb.expand(__ -> Collections.emptyList(), input))
.hint(
Hint.builder()
.alias("alias1")
@@ -52,7 +52,7 @@ void expandConsistent() {
void expandSwitching() {
Rel rel =
Expand.builder()
- .from(b.expand(__ -> Collections.emptyList(), input))
+ .from(sb.expand(__ -> Collections.emptyList(), input))
.hint(Hint.builder().addAllOutputNames(Arrays.asList("name1", "name2")).build())
.fields(
Stream.of(
diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java
index e19f2ed7d..49895b01d 100644
--- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java
@@ -47,10 +47,10 @@
class ExtensionRoundtripTest extends TestBase {
final ProtoRelConverter protoRelConverter =
- new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection);
+ new StringHolderHandlingProtoRelConverter(functionCollector, extensions);
final Rel commonTable =
- b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
+ sb.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
final AdvancedExtension commonExtension =
AdvancedExtension.builder()
@@ -105,7 +105,7 @@ void namedScan() {
Rel rel =
NamedScan.builder()
.from(
- b.namedScan(
+ sb.namedScan(
Collections.emptyList(), Collections.emptyList(), Collections.emptyList()))
.commonExtension(commonExtension)
.extension(relExtension)
@@ -123,7 +123,7 @@ void extensionTable() {
void filter() {
Rel rel =
Filter.builder()
- .from(b.filter(__ -> b.bool(true), commonTable))
+ .from(sb.filter(__ -> sb.bool(true), commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -134,7 +134,7 @@ void filter() {
void fetch() {
Rel rel =
Fetch.builder()
- .from(b.fetch(1, 2, commonTable))
+ .from(sb.fetch(1, 2, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -145,7 +145,7 @@ void fetch() {
void aggregate() {
Rel rel =
Aggregate.builder()
- .from(b.aggregate(b::grouping, __ -> Collections.emptyList(), commonTable))
+ .from(sb.aggregate(sb::grouping, __ -> Collections.emptyList(), commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -156,7 +156,7 @@ void aggregate() {
void sort() {
Rel rel =
Sort.builder()
- .from(b.sort(__ -> Collections.emptyList(), commonTable))
+ .from(sb.sort(__ -> Collections.emptyList(), commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -167,7 +167,7 @@ void sort() {
void join() {
Rel rel =
Join.builder()
- .from(b.innerJoin(__ -> b.bool(true), commonTable, commonTable))
+ .from(sb.innerJoin(__ -> sb.bool(true), commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -182,7 +182,7 @@ void hashJoin() {
Rel relWithoutKeys =
HashJoin.builder()
.from(
- b.hashJoin(
+ sb.hashJoin(
leftEmptyKeys,
rightEmptyKeys,
HashJoin.JoinType.INNER,
@@ -202,7 +202,7 @@ void mergeJoin() {
Rel relWithoutKeys =
MergeJoin.builder()
.from(
- b.mergeJoin(
+ sb.mergeJoin(
leftEmptyKeys,
rightEmptyKeys,
MergeJoin.JoinType.INNER,
@@ -219,8 +219,8 @@ void nestedLoopJoin() {
Rel rel =
NestedLoopJoin.builder()
.from(
- b.nestedLoopJoin(
- __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable))
+ sb.nestedLoopJoin(
+ __ -> sb.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -231,7 +231,7 @@ void nestedLoopJoin() {
void project() {
Rel rel =
Project.builder()
- .from(b.project(__ -> Collections.emptyList(), commonTable))
+ .from(sb.project(__ -> Collections.emptyList(), commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -242,7 +242,7 @@ void project() {
void expand() {
Rel rel =
Expand.builder()
- .from(b.expand(__ -> Collections.emptyList(), commonTable))
+ .from(sb.expand(__ -> Collections.emptyList(), commonTable))
.commonExtension(commonExtension)
.build();
verifyRoundTrip(rel);
@@ -252,7 +252,7 @@ void expand() {
void set() {
Rel rel =
Set.builder()
- .from(b.set(Set.SetOp.UNION_ALL, commonTable))
+ .from(sb.set(Set.SetOp.UNION_ALL, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -284,7 +284,7 @@ void extensionLeafRel() {
void cross() {
Rel rel =
Cross.builder()
- .from(b.cross(commonTable, commonTable))
+ .from(sb.cross(commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -297,13 +297,13 @@ class ExtensionThroughExpression {
// Check that custom extensions in these relations can be handled.
Rel baseTable =
- b.namedScan(
+ sb.namedScan(
Stream.of("test_table").collect(Collectors.toList()),
Stream.of("test_column").collect(Collectors.toList()),
Stream.of(TypeCreator.REQUIRED.I64).collect(Collectors.toList()));
Rel relWithEnhancement =
Project.builder()
- .from(b.project(input -> Collections.emptyList(), baseTable))
+ .from(sb.project(input -> Collections.emptyList(), baseTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
@@ -311,7 +311,7 @@ class ExtensionThroughExpression {
@Test
void scalarSubquery() {
Project rel =
- b.project(
+ sb.project(
input ->
Stream.of(
Expression.ScalarSubquery.builder()
@@ -327,7 +327,7 @@ void scalarSubquery() {
@Test
void inPredicate() {
Project rel =
- b.project(
+ sb.project(
input ->
Stream.of(
Expression.InPredicate.builder()
@@ -342,7 +342,7 @@ void inPredicate() {
@Test
void setPredicate() {
Project rel =
- b.project(
+ sb.project(
input ->
Stream.of(
Expression.SetPredicate.builder()
diff --git a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java
index a04f0bad4..4e00797c2 100644
--- a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java
@@ -17,7 +17,7 @@
class FieldReferenceRoundtripTest extends TestBase {
final Rel baseTable =
- b.namedScan(
+ sb.namedScan(
Collections.singletonList("test_table"),
Arrays.asList("id", "amount", "name", "nested_struct"),
Arrays.asList(
@@ -30,7 +30,7 @@ class FieldReferenceRoundtripTest extends TestBase {
void simpleStructFieldReference() {
// Test simple root struct field reference via projection
Rel projection =
- Project.builder().input(baseTable).addExpressions(b.fieldReference(baseTable, 0)).build();
+ Project.builder().input(baseTable).addExpressions(sb.fieldReference(baseTable, 0)).build();
verifyRoundTrip(projection);
}
@@ -42,9 +42,9 @@ void multipleFieldReferences() {
Project.builder()
.input(baseTable)
.addExpressions(
- b.fieldReference(baseTable, 0),
- b.fieldReference(baseTable, 1),
- b.fieldReference(baseTable, 2))
+ sb.fieldReference(baseTable, 0),
+ sb.fieldReference(baseTable, 1),
+ sb.fieldReference(baseTable, 2))
.build();
verifyRoundTrip(projection);
@@ -53,7 +53,8 @@ void multipleFieldReferences() {
@Test
void fieldReferenceInFilter() {
// Test field reference in filter condition
- Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0));
+ Expression condition =
+ sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -63,7 +64,7 @@ void fieldReferenceInFilter() {
@Test
void fieldReferenceInComplexExpression() {
// Test field reference as part of arithmetic expression
- Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0));
+ Expression add = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0));
Rel projection = Project.builder().input(baseTable).addExpressions(add).build();
@@ -76,13 +77,13 @@ void fieldReferenceInNestedProjection() {
Rel firstProjection =
Project.builder()
.input(baseTable)
- .addExpressions(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 2))
+ .addExpressions(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 2))
.build();
Rel secondProjection =
Project.builder()
.input(firstProjection)
- .addExpressions(b.fieldReference(firstProjection, 1))
+ .addExpressions(sb.fieldReference(firstProjection, 1))
.build();
verifyRoundTrip(secondProjection);
@@ -95,10 +96,10 @@ void fieldReferenceAllFields() {
Project.builder()
.input(baseTable)
.addExpressions(
- b.fieldReference(baseTable, 0),
- b.fieldReference(baseTable, 1),
- b.fieldReference(baseTable, 2),
- b.fieldReference(baseTable, 3))
+ sb.fieldReference(baseTable, 0),
+ sb.fieldReference(baseTable, 1),
+ sb.fieldReference(baseTable, 2),
+ sb.fieldReference(baseTable, 3))
.build();
verifyRoundTrip(projection);
@@ -108,9 +109,9 @@ void fieldReferenceAllFields() {
void fieldReferenceWithBooleanLogic() {
// Test field references in boolean expressions
Expression condition =
- b.and(
- b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)),
- b.equal(b.fieldReference(baseTable, 2), b.str("test")));
+ sb.and(
+ sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0)),
+ sb.equal(sb.fieldReference(baseTable, 2), sb.str("test")));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -120,8 +121,8 @@ void fieldReferenceWithBooleanLogic() {
@Test
void fieldReferenceInMultipleArithmetic() {
// Test multiple field references in arithmetic
- Expression add = b.add(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1));
- Expression multiply = b.multiply(add, b.fieldReference(baseTable, 1));
+ Expression add = sb.add(sb.fieldReference(baseTable, 1), sb.fieldReference(baseTable, 1));
+ Expression multiply = sb.multiply(add, sb.fieldReference(baseTable, 1));
Rel projection = Project.builder().input(baseTable).addExpressions(multiply).build();
@@ -135,9 +136,9 @@ void fieldReferenceReordering() {
Project.builder()
.input(baseTable)
.addExpressions(
- b.fieldReference(baseTable, 3),
- b.fieldReference(baseTable, 0),
- b.fieldReference(baseTable, 2))
+ sb.fieldReference(baseTable, 3),
+ sb.fieldReference(baseTable, 0),
+ sb.fieldReference(baseTable, 2))
.build();
verifyRoundTrip(projection);
@@ -150,9 +151,9 @@ void sameFieldReferencedMultipleTimes() {
Project.builder()
.input(baseTable)
.addExpressions(
- b.fieldReference(baseTable, 0),
- b.fieldReference(baseTable, 0),
- b.fieldReference(baseTable, 0))
+ sb.fieldReference(baseTable, 0),
+ sb.fieldReference(baseTable, 0),
+ sb.fieldReference(baseTable, 0))
.build();
verifyRoundTrip(projection);
diff --git a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java
index 4ee77133f..95fce4c1c 100644
--- a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java
@@ -11,7 +11,7 @@
class FilterRelRoundtripTest extends TestBase {
final Rel baseTable =
- b.namedScan(
+ sb.namedScan(
Collections.singletonList("test_table"),
Arrays.asList("id", "amount", "name", "status"),
Arrays.asList(R.I64, R.FP64, R.STRING, R.BOOLEAN));
@@ -19,7 +19,7 @@ class FilterRelRoundtripTest extends TestBase {
@Test
void simpleEqualityFilter() {
// Filter: WHERE id = 100
- Expression condition = b.equal(b.fieldReference(baseTable, 0), b.i32(100));
+ Expression condition = sb.equal(sb.fieldReference(baseTable, 0), sb.i32(100));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -29,7 +29,7 @@ void simpleEqualityFilter() {
@Test
void stringComparisonFilter() {
// Filter: WHERE name = 'John'
- Expression condition = b.equal(b.fieldReference(baseTable, 2), b.str("John"));
+ Expression condition = sb.equal(sb.fieldReference(baseTable, 2), sb.str("John"));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -40,9 +40,9 @@ void stringComparisonFilter() {
void andConditionFilter() {
// Filter: WHERE id = 10 AND amount = 100.0
Expression condition =
- b.and(
- b.equal(b.fieldReference(baseTable, 0), b.i32(10)),
- b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0)));
+ sb.and(
+ sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10)),
+ sb.equal(sb.fieldReference(baseTable, 1), sb.fp64(100.0)));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -53,9 +53,9 @@ void andConditionFilter() {
void orConditionFilter() {
// Filter: WHERE id = 5 OR id = 95
Expression condition =
- b.or(
- b.equal(b.fieldReference(baseTable, 0), b.i32(5)),
- b.equal(b.fieldReference(baseTable, 0), b.i32(95)));
+ sb.or(
+ sb.equal(sb.fieldReference(baseTable, 0), sb.i32(5)),
+ sb.equal(sb.fieldReference(baseTable, 0), sb.i32(95)));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -66,12 +66,12 @@ void orConditionFilter() {
void complexBooleanFilter() {
// Filter: WHERE (id = 10 AND amount = 100) OR status = true
Expression andCondition =
- b.and(
- b.equal(b.fieldReference(baseTable, 0), b.i32(10)),
- b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0)));
+ sb.and(
+ sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10)),
+ sb.equal(sb.fieldReference(baseTable, 1), sb.fp64(100.0)));
Expression condition =
- b.or(andCondition, b.equal(b.fieldReference(baseTable, 3), b.bool(true)));
+ sb.or(andCondition, sb.equal(sb.fieldReference(baseTable, 3), sb.bool(true)));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -81,7 +81,8 @@ void complexBooleanFilter() {
@Test
void multipleFieldComparison() {
// Filter: WHERE id = amount (comparing two fields)
- Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 1));
+ Expression condition =
+ sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 1));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -91,10 +92,10 @@ void multipleFieldComparison() {
@Test
void nestedFilters() {
// Apply filter on top of another filter
- Expression firstCondition = b.equal(b.fieldReference(baseTable, 0), b.i32(10));
+ Expression firstCondition = sb.equal(sb.fieldReference(baseTable, 0), sb.i32(10));
Rel firstFilter = Filter.builder().input(baseTable).condition(firstCondition).build();
- Expression secondCondition = b.equal(b.fieldReference(firstFilter, 1), b.fp64(100.0));
+ Expression secondCondition = sb.equal(sb.fieldReference(firstFilter, 1), sb.fp64(100.0));
Rel secondFilter = Filter.builder().input(firstFilter).condition(secondCondition).build();
verifyRoundTrip(secondFilter);
@@ -103,8 +104,8 @@ void nestedFilters() {
@Test
void filterWithArithmeticExpression() {
// Filter: WHERE amount * 2 = 100
- Expression multiply = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0));
- Expression condition = b.equal(multiply, b.fp64(100.0));
+ Expression multiply = sb.multiply(sb.fieldReference(baseTable, 1), sb.fp64(2.0));
+ Expression condition = sb.equal(multiply, sb.fp64(100.0));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -114,7 +115,7 @@ void filterWithArithmeticExpression() {
@Test
void filterWithBooleanField() {
// Filter: WHERE status (direct boolean field)
- Expression condition = b.fieldReference(baseTable, 3);
+ Expression condition = sb.fieldReference(baseTable, 3);
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
@@ -124,8 +125,8 @@ void filterWithBooleanField() {
@Test
void filterWithAddition() {
// Filter: WHERE id + id = id (field with itself)
- Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0));
- Expression condition = b.equal(add, b.fieldReference(baseTable, 0));
+ Expression add = sb.add(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0));
+ Expression condition = sb.equal(add, sb.fieldReference(baseTable, 0));
Rel filter = Filter.builder().input(baseTable).condition(condition).build();
diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java
index 268729c68..3e3f42646 100644
--- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java
@@ -54,10 +54,7 @@ void roundtripTest(Method m, List