Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion core/src/main/java/io/substrait/expression/EnumArg.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ default <R, C extends VisitationContext, E extends Throwable> R accept(
}

/**
 * Creates an {@link EnumArg} holding {@code option}, after checking it against the options
 * declared by {@code enumArg}.
 *
 * @param enumArg the enum argument declaration whose options constrain the value
 * @param option the option to wrap
 * @return an {@link EnumArg} whose value is {@code option}
 * @throws IllegalArgumentException if {@code option} is not one of {@code enumArg.options()}
 */
static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
  // Validate explicitly rather than with `assert`: assertions are disabled by default at runtime,
  // and when enabled they would raise AssertionError before this check could run.
  if (!enumArg.options().contains(option)) {
    throw new IllegalArgumentException(
        String.format("EnumArg value %s not valid for options: %s", option, enumArg.options()));
  }
  return builder().value(Optional.of(option)).build();
}

Expand Down
1 change: 1 addition & 0 deletions examples/isthmus-api/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ _apps
_data
**/*/bin
build
substrait.plan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to ignore these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's created for the examples, and I was using them as a test and kept adding it to the commit.

Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ public class EnumConverter {
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "rtrim:str_str", 0),
SqlTrimFunction.Flag.class);
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 0),
TimeUnitRange.class);
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1),
ExtractIndexing.class);
}

private static Optional<Enum<?>> constructValue(
Expand All @@ -84,6 +90,10 @@ private static Optional<Enum<?>> constructValue(
return option.get().map(SqlTrimFunction.Flag::valueOf);
}

// ExtractIndexing does not need to be converted here. Calcite
// doesn't have the concept of indexing. Its date
// functions are all indexed from 1.

return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension.ScalarFunctionVariant;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;

/**
 * Custom mapping for the Calcite MONTH/DAY/QUARTER functions.
 *
 * <p>These come from Calcite as 2 argument functions (like YEAR) but in Substrait these functions
 * take 3 arguments; the additional one states whether the result is a 0 or 1 based value. Calcite
 * is 1 based in this case.
 *
 * <p>The MONTH etc. functions are therefore mapped to a different Substrait function: {@link
 * #toSubstrait(RexCall)} inserts the indexing argument and {@link
 * #getExpressionArguments(Expression.ScalarFunctionInvocation)} removes it again.
 */
final class ExtractDateFunctionMapper implements ScalarFunctionMapper {

  /** Key of the Substrait extract variant that carries the additional indexing argument. */
  private static final String EXTRACT_DATE_KEY = "extract:req_req_date";

  /** Position of the indexing argument within the Substrait argument list. */
  private static final int INDEXING_ARG_POSITION = 1;

  /** All supplied functions named "extract", keyed by their {@code toString()} value. */
  private final Map<String, ScalarFunctionVariant> extractFunctions;

  /**
   * Creates the mapper from the available scalar function variants.
   *
   * @param functions the scalar function variants to search; only those named "extract" are kept
   */
  public ExtractDateFunctionMapper(List<ScalarFunctionVariant> functions) {
    Map<String, ScalarFunctionVariant> fns =
        functions.stream()
            .filter(f -> "extract".equals(f.name()))
            .collect(Collectors.toMap(Object::toString, f -> f));

    this.extractFunctions = Collections.unmodifiableMap(fns);
  }

  /**
   * Maps a Calcite {@code EXTRACT} of QUARTER, MONTH or DAY from a DATE onto the Substrait variant
   * that takes an indexing argument.
   *
   * @param call the Calcite call to inspect
   * @return the mapping with {@link ExtractIndexing#ONE} inserted as the second operand, or an
   *     empty {@link Optional} when the call is not handled by this mapper
   */
  @Override
  public Optional<SubstraitFunctionMapping> toSubstrait(final RexCall call) {
    if (!SqlStdOperatorTable.EXTRACT.equals(call.op)) {
      return Optional.empty();
    }

    if (call.operandCount() < 2) {
      return Optional.empty();
    }

    final RexNode extractType = call.operands.get(0);
    // The time unit must be a SYMBOL literal; the instanceof check keeps the cast below safe.
    if (!(extractType instanceof RexLiteral)
        || extractType.getType().getSqlTypeName() != SqlTypeName.SYMBOL) {
      return Optional.empty();
    }

    final RexNode dataType = call.operands.get(1);
    if (dataType.getType().getSqlTypeName() != SqlTypeName.DATE) {
      return Optional.empty();
    }

    final TimeUnitRange value = ((RexLiteral) extractType).getValueAs(TimeUnitRange.class);
    if (value == null) {
      // Switching on a null enum reference would throw a NullPointerException.
      return Optional.empty();
    }

    switch (value) {
      case QUARTER:
      case MONTH:
      case DAY:
        {
          final ScalarFunctionVariant substraitFn = this.extractFunctions.get(EXTRACT_DATE_KEY);
          if (substraitFn == null) {
            // The supplied functions do not include the indexed extract variant, so there is
            // nothing to map to. List.of(null) would otherwise throw a NullPointerException.
            return Optional.empty();
          }

          final List<RexNode> newOperands = new LinkedList<>(call.operands);
          newOperands.add(
              INDEXING_ARG_POSITION, RexBuilder.DEFAULT.makeFlag(ExtractIndexing.ONE));

          return Optional.of(
              new SubstraitFunctionMapping("extract", newOperands, List.of(substraitFn)));
        }
      default:
        return Optional.empty();
    }
  }

  /**
   * Returns the arguments of an {@code extract:req_req_date} invocation with the indexing argument
   * removed.
   *
   * @param expression the Substrait scalar function invocation to inspect
   * @return the remaining arguments, or an empty {@link Optional} for any other function
   */
  @Override
  public Optional<List<FunctionArg>> getExpressionArguments(
      final Expression.ScalarFunctionInvocation expression) {
    final String name = expression.declaration().toString();

    if (EXTRACT_DATE_KEY.equals(name)) {
      final List<FunctionArg> newArgs = new LinkedList<>(expression.arguments());
      newArgs.remove(INDEXING_ARG_POSITION);

      return Optional.of(newArgs);
    }
    return Optional.empty();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.substrait.isthmus.expression;

/**
 * The INDEXING property of the date functions.
 *
 * <p>Selects whether a number such as the month is counted from 0 or from 1.
 */
public enum ExtractIndexing {
  /** Values are 1 based. */
  ONE,

  /** Values are 0 based. */
  ZERO;
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public Expression.Literal convert(RexLiteral literal) {
return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class));
case CHAR:
{
Comparable val = literal.getValue();
Comparable<?> val = literal.getValue();
if (val instanceof NlsString) {
NlsString nls = (NlsString) val;
return ExpressionCreator.fixedChar(n, nls.getValue());
Expand Down Expand Up @@ -127,11 +127,11 @@ public Expression.Literal convert(RexLiteral literal) {
case SYMBOL:
{
Object value = literal.getValue();
// case TimeUnitRange tur -> string(n, tur.name());
if (value instanceof NlsString) {
return ExpressionCreator.string(n, ((NlsString) value).getValue());
} else if (value instanceof Enum) {
Enum<?> v = (Enum<?>) value;

Optional<Expression.Literal> r =
EnumConverter.canConvert(v)
? Optional.of(ExpressionCreator.string(n, v.name()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ public class ScalarFunctionConverter
Expression,
ScalarFunctionConverter.WrappedScalarCall>
implements CallConverter {

/**
* Function mappers provide a hook point for any custom mapping to Substrait functions and
* arguments.
Expand All @@ -43,7 +42,11 @@ public ScalarFunctionConverter(
TypeConverter typeConverter) {
super(functions, additionalSignatures, typeFactory, typeConverter);

mappers = List.of(new TrimFunctionMapper(functions), new SqrtFunctionMapper(functions));
mappers =
List.of(
new TrimFunctionMapper(functions),
new SqrtFunctionMapper(functions),
new ExtractDateFunctionMapper(functions));
}

@Override
Expand All @@ -68,6 +71,7 @@ private Optional<SubstraitFunctionMapping> getMappingForCall(final RexCall call)
.orElse(Optional.empty());
}

/** Application of the more complex mappings. */
private Optional<Expression> mappedConvert(
SubstraitFunctionMapping mapping,
RexCall call,
Expand All @@ -85,6 +89,7 @@ public Stream<RexNode> getOperands() {
return attemptMatch(finder, wrapped, topLevelConverter);
}

/** Default conversion for functions that have simple 1:1 mappings. */
private Optional<Expression> defaultConvert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
FunctionFinder finder = signatures.get(call.op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ class FunctionConversionTest extends PlanTestBase {
@Test
void subtractDateIDay() {
// When this function is converted to Calcite, if the Calcite type derivation is used an
// java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that this is being
// mapped to the wrong
// Calcite function.
// java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that
// this is being mapped to the wrong Calcite function.
// TODO: https://github.com/substrait-io/substrait-java/issues/377
Expression.ScalarFunctionInvocation expr =
b.scalarFn(
Expand Down Expand Up @@ -183,6 +182,25 @@ void extractTimeScalarFunction() {
assertEquals("EXTRACT(FLAG(MINUTE), 00:00:00:TIME(6))", extract.toString());
}

@Test
void extractDateWithIndexing() {
  // Substrait invocation of the date extract variant, with the indexing argument set to "ONE".
  final ScalarFunctionInvocation substraitExtract =
      b.scalarFn(
          DefaultExtensionCatalog.FUNCTIONS_DATETIME,
          "extract:req_req_date",
          TypeCreator.REQUIRED.I64,
          EnumArg.builder().value("MONTH").build(),
          EnumArg.builder().value("ONE").build(),
          Expression.DateLiteral.builder().value(0).build());

  final RexNode converted =
      substraitExtract.accept(expressionRexConverter, Context.newContext());

  assertEquals(SqlKind.EXTRACT, converted.getKind());
  // assertInstanceOf hands back the value already cast to the expected type.
  final RexCall extractCall = assertInstanceOf(RexCall.class, converted);
  assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01)", extractCall.toString());
}

@Test
void unsupportedExtractTimestampTzWithIndexing() {
ScalarFunctionInvocation reqReqTstzFn =
Expand Down Expand Up @@ -249,22 +267,6 @@ void unsupportedExtractPrecisionTimestampWithIndexing() {
() -> reqReqPtsFn.accept(expressionRexConverter, Context.newContext()));
}

@Test
void unsupportedExtractDateWithIndexing() {
ScalarFunctionInvocation reqReqDateFn =
b.scalarFn(
DefaultExtensionCatalog.FUNCTIONS_DATETIME,
"extract:req_req_date",
TypeCreator.REQUIRED.I64,
EnumArg.builder().value("MONTH").build(),
EnumArg.builder().value("ONE").build(),
Expression.DateLiteral.builder().value(0).build());

assertThrows(
UnsupportedOperationException.class,
() -> reqReqDateFn.accept(expressionRexConverter, Context.newContext()));
}

@Test
void concatStringLiteralAndVarchar() throws Exception {
assertProtoPlanRoundrip("select 'part_'||P_NAME from PART");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,11 @@ void caseWhenTest() throws Exception {
assertFullRoundTrip(
"select case when p_size > 100 then 'large' when p_size > 50 then 'medium' else 'small' end from part");
}

@Test
void dateFunctions() throws Exception {
  // Function-call syntax: MONTH(...) and YEAR(...).
  final String functionSyntax = "select month(o_orderdate),year(o_orderdate) from orders";
  // EXTRACT(... FROM ...) syntax for month, year and day.
  final String extractSyntax =
      "select extract(month from o_orderdate),extract(year from o_orderdate),extract(day from o_orderdate) from orders";

  assertSqlSubstraitRelRoundTrip(functionSyntax);
  assertSqlSubstraitRelRoundTrip(extractSyntax);
}
}