Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions core/src/test/java/io/substrait/TestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
Expand All @@ -10,24 +12,38 @@
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.TypeCreator;
import java.io.IOException;

public abstract class TestBase {

protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection =
DefaultExtensionCatalog.DEFAULT_COLLECTION;
protected static final TypeCreator R = TypeCreator.REQUIRED;
protected static final TypeCreator N = TypeCreator.NULLABLE;

protected TypeCreator R = TypeCreator.REQUIRED;
protected TypeCreator N = TypeCreator.NULLABLE;
protected final SimpleExtension.ExtensionCollection extensions;

protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
protected ExtensionCollector functionCollector = new ExtensionCollector();
protected RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector);
protected ProtoRelConverter protoRelConverter =
new ProtoRelConverter(functionCollector, defaultExtensionCollection);

protected SubstraitBuilder sb;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opted to use sb for SubstraitBuilder, which felt slightly more descriptive than b. A short-name is useful because when it's used, it's called frequently to build a plan component. A descriptive name like substraitBuilder ends up being noisy.

protected ProtoRelConverter protoRelConverter;

protected TestBase() {
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
}

protected TestBase(SimpleExtension.ExtensionCollection extensions) {
this.extensions = extensions;
this.sb = new SubstraitBuilder(extensions);
this.protoRelConverter = new ProtoRelConverter(functionCollector, extensions);
}

protected void verifyRoundTrip(Rel rel) {
io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel);
Rel relReturned = protoRelConverter.from(protoRel);
assertEquals(rel, relReturned);
}

public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.TestBase;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
import io.substrait.type.TypeCreator;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/** Tests for variadic parameter consistency validation in Expression. */
class VariadicParameterConsistencyTest {

private static final TypeCreator R = TypeCreator.of(false);
class VariadicParameterConsistencyTest extends TestBase {

/**
* Helper method to create a ScalarFunctionInvocation and test if it validates correctly. The
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private static ImmutableExpressionReference getFieldReferenceExpression() {

private static ImmutableExpressionReference getScalarFunctionExpression() {
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
new SubstraitBuilder(defaultExtensionCollection)
new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION)
.scalarFn(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC_DECIMAL,
"add:dec_dec",
Expand All @@ -111,8 +111,9 @@ private static ImmutableAggregateFunctionReference getAggregateFunctionReference
.function(
AggregateFunctionInvocation.builder()
.arguments(Collections.emptyList())
.declaration(defaultExtensionCollection.aggregateFunctions().get(0))
.outputType(TypeCreator.of(false).I64)
.declaration(
DefaultExtensionCatalog.DEFAULT_COLLECTION.aggregateFunctions().get(0))
.outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.SortDirection;
import io.substrait.expression.ExpressionCreator;
Expand Down Expand Up @@ -48,9 +48,7 @@
import io.substrait.utils.StringHolderHandlingProtoExtensionConverter;
import org.junit.jupiter.api.Test;

class AdvancedExtensionRelProtoConversionTest {
final SubstraitBuilder builder = new SubstraitBuilder(DefaultExtensionCatalog.DEFAULT_COLLECTION);

class AdvancedExtensionRelProtoConversionTest extends TestBase {
final StringHolder enhanced = new StringHolder("ENHANCED");
final StringHolder optimized = new StringHolder("OPTIMIZED");
final AdvancedExtension<?, ?> extension =
Expand Down Expand Up @@ -168,7 +166,7 @@ void testAggregateRelConversionRoundtrip() throws Exception {
.initialSchema(
NamedStruct.builder().struct(TypeCreator.REQUIRED.struct()).build())
.build())
.addMeasures(builder.countStar())
.addMeasures(sb.countStar())
.extension(extension)
.build();

Expand All @@ -191,7 +189,7 @@ void testSortRelConversionRoundtrip() throws Exception {
.addSortFields(
SortField.builder()
.direction(SortDirection.ASC_NULLS_FIRST)
.expr(builder.fieldReference(scan, 0))
.expr(sb.fieldReference(scan, 0))
.build())
.extension(extension)
.build();
Expand All @@ -214,8 +212,7 @@ void testJoinRelConversionRoundtrip() throws Exception {
.left(scan)
.right(scan)
.joinType(JoinType.INNER)
.condition(
builder.equal(builder.fieldReference(scan, 0), builder.fieldReference(scan, 0)))
.condition(sb.equal(sb.fieldReference(scan, 0), sb.fieldReference(scan, 0)))
.extension(extension)
.build();

Expand All @@ -235,7 +232,7 @@ void testProjectRelConversionRoundtrip() throws Exception {
final Project rel =
Project.builder()
.input(scan)
.addExpressions(builder.fieldReference(scan, 0))
.addExpressions(sb.fieldReference(scan, 0))
.extension(extension)
.build();

Expand Down Expand Up @@ -384,12 +381,12 @@ void testNamedUpdateRelConversionRoundtrip() throws Exception {
.addNames("CUSTOMER")
.tableSchema(schema)
.condition(
builder.equal(
sb.equal(
FieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.BOOLEAN)
.build(),
builder.bool(true)))
sb.bool(true)))
.addTransformations(
TransformExpression.builder()
.columnTarget(0)
Expand Down
73 changes: 40 additions & 33 deletions core/src/test/java/io/substrait/extension/TypeExtensionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import static io.substrait.type.TypeCreator.REQUIRED;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.TestBase;
import io.substrait.plan.Plan;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.InputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -24,51 +24,56 @@
* <li>Roundtrip between POJO and Proto
* </ul>
*/
class TypeExtensionTest {

static final TypeCreator R = TypeCreator.of(false);
class TypeExtensionTest extends TestBase {

static final String URN = "extension:test:custom_extensions";
final SimpleExtension.ExtensionCollection extensionCollection;
static final SimpleExtension.ExtensionCollection CUSTOM_EXTENSION;

{
String path = "/extensions/custom_extensions.yaml";
InputStream inputStream = this.getClass().getResourceAsStream(path);
extensionCollection = SimpleExtension.load(path, inputStream);
static {
try {
String customExtensionStr = asString("extensions/custom_extensions.yaml");
CUSTOM_EXTENSION = SimpleExtension.load(URN, customExtensionStr);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

final SubstraitBuilder b = new SubstraitBuilder(extensionCollection);
Type customType1 = b.userDefinedType(URN, "customType1");
Type customType2 = b.userDefinedType(URN, "customType2");
Type customType1 = sb.userDefinedType(URN, "customType1");
Type customType2 = sb.userDefinedType(URN, "customType2");
final PlanProtoConverter planProtoConverter = new PlanProtoConverter();
final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection);
final ProtoPlanConverter protoPlanConverter;

TypeExtensionTest() {
super(CUSTOM_EXTENSION);
this.protoPlanConverter = new ProtoPlanConverter(extensions);
}

@Test
void roundtripCustomType() {
// CREATE TABLE example (custom_type_column custom_type1, i64_column BIGINT);
List<String> tableName = Stream.of("example").collect(Collectors.toList());
List<String> columnNames =
Stream.of("custom_type_column", "i64_column").collect(Collectors.toList());
List<io.substrait.type.Type> types = Stream.of(customType1, R.I64).collect(Collectors.toList());
List<Type> types = Stream.of(customType1, R.I64).collect(Collectors.toList());

// SELECT custom_type_column, scalar1(custom_type_column), scalar2(i64_column)
// FROM example
Plan plan =
b.plan(
b.root(
b.project(
sb.plan(
sb.root(
sb.project(
input ->
Stream.of(
b.fieldReference(input, 0),
b.scalarFn(
sb.fieldReference(input, 0),
sb.scalarFn(
URN,
"scalar1:u!customType1",
R.I64,
b.fieldReference(input, 0)),
b.scalarFn(
URN, "scalar2:i64", customType2, b.fieldReference(input, 1)))
sb.fieldReference(input, 0)),
sb.scalarFn(
URN, "scalar2:i64", customType2, sb.fieldReference(input, 1)))
.collect(Collectors.toList()),
b.namedScan(tableName, columnNames, types))));
sb.namedScan(tableName, columnNames, types))));

io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
Expand All @@ -80,19 +85,21 @@ void roundtripNumberedAnyTypes() {
List<String> tableName = Stream.of("example").collect(Collectors.toList());
List<String> columnNames =
Stream.of("array_i64_type_column", "array_i64_column").collect(Collectors.toList());
List<io.substrait.type.Type> types =
Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList());
List<Type> types = Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList());

Plan plan =
b.plan(
b.root(
b.project(
sb.plan(
sb.root(
sb.project(
input ->
Stream.of(
b.scalarFn(
URN, "array_index:list_i64", R.I64, b.fieldReference(input, 0)))
sb.scalarFn(
URN,
"array_index:list_i64",
R.I64,
sb.fieldReference(input, 0)))
.collect(Collectors.toList()),
b.namedScan(tableName, columnNames, types))));
sb.namedScan(tableName, columnNames, types))));
io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class ProtoRelConverterTest extends TestBase {

final NamedScan commonTable =
b.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
sb.namedScan(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());

/**
* Verify default behaviour of {@link ProtoRelConverter} in the presence of {@link
Expand Down Expand Up @@ -121,9 +121,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter())

final Rel relFromProto =
new ProtoRelConverter(
functionCollector,
defaultExtensionCollection,
new StringHolderHandlingProtoExtensionConverter())
functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter())
.from(protoRel);

assertEquals(rel, relFromProto);
Expand All @@ -140,9 +138,7 @@ functionCollector, new StringHolderHandlingExtensionProtoConverter())

final Rel relFromProto =
new ProtoRelConverter(
functionCollector,
defaultExtensionCollection,
new StringHolderHandlingProtoExtensionConverter())
functionCollector, extensions, new StringHolderHandlingProtoExtensionConverter())
.from(protoRel);

assertEquals(rel, relFromProto);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import io.substrait.relation.RelProtoConverter;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.TypeCreator;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -33,8 +32,6 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio

ExtensionCollector functionCollector = new ExtensionCollector();
RelProtoConverter to = new RelProtoConverter(functionCollector);
io.substrait.extension.SimpleExtension.ExtensionCollection extensions =
defaultExtensionCollection;
ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions);

io.substrait.relation.ImmutableMeasure measure =
Expand All @@ -43,7 +40,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio
AggregateFunctionInvocation.builder()
.arguments(Collections.emptyList())
.declaration(extensions.aggregateFunctions().get(0))
.outputType(TypeCreator.of(false).I64)
.outputType(R.I64)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(invocation)
.options(
Expand All @@ -56,7 +53,7 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio
Arrays.asList(
Expression.SortField.builder()
// SORT BY decimal
.expr(b.fieldReference(input, 0))
.expr(sb.fieldReference(input, 0))
.direction(Expression.SortDirection.ASC_NULLS_LAST)
.build()))
.build())
Expand Down
Loading