diff --git a/core/src/main/java/io/substrait/expression/EnumArg.java b/core/src/main/java/io/substrait/expression/EnumArg.java index 956c9a653..a22fa9fe1 100644 --- a/core/src/main/java/io/substrait/expression/EnumArg.java +++ b/core/src/main/java/io/substrait/expression/EnumArg.java @@ -26,7 +26,10 @@ default R accept( } static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) { - assert (enumArg.options().contains(option)); + if (!enumArg.options().contains(option)) { + throw new IllegalArgumentException( + String.format("EnumArg value %s not valid for options: %s", option, enumArg.options())); + } return builder().value(Optional.of(option)).build(); } diff --git a/examples/isthmus-api/.gitignore b/examples/isthmus-api/.gitignore index 6ead40eb5..8972674d3 100644 --- a/examples/isthmus-api/.gitignore +++ b/examples/isthmus-api/.gitignore @@ -2,3 +2,4 @@ _apps _data **/*/bin build +substrait.plan diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index b69ef9b02..2f95ca276 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -72,6 +72,12 @@ public class EnumConverter { calciteEnumMap.put( argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "rtrim:str_str", 0), SqlTrimFunction.Flag.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 0), + TimeUnitRange.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1), + ExtractIndexing.class); } private static Optional> constructValue( @@ -84,6 +90,10 @@ private static Optional> constructValue( return option.get().map(SqlTrimFunction.Flag::valueOf); } + // ExtractIndexing does not need to be converted here. Calcite + // doesn't have the concept of the indexing. It's date + // functions are all indexed from 1 + return Optional.empty(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractDateFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractDateFunctionMapper.java new file mode 100644 index 000000000..073fa9f3b --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractDateFunctionMapper.java @@ -0,0 +1,96 @@ +package io.substrait.isthmus.expression; + +import io.substrait.expression.Expression; +import io.substrait.expression.FunctionArg; +import io.substrait.extension.SimpleExtension.ScalarFunctionVariant; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.apache.calcite.avatica.util.TimeUnitRange; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Custom mapping for the Calcite MONTH/DAY/QUARTER functions. + * + *

These come from Calcite as 2 argument functions (like YEAR) but in Substrait these functions + * are 3 arguments; the additional being if this is a 0 or 1 based value. Calcite is a 1 based value + * in this case. + * + *

We need to therefore map the MONTH etc functions to a different Substrait function. + */ +final class ExtractDateFunctionMapper implements ScalarFunctionMapper { + + private final Map extractFunctions; + + public ExtractDateFunctionMapper(List functions) { + + Map fns = + functions.stream() + .filter(f -> "extract".equals(f.name())) + .collect(Collectors.toMap(Object::toString, f -> f)); + + this.extractFunctions = Collections.unmodifiableMap(fns); + } + + @Override + public Optional toSubstrait(final RexCall call) { + if (!SqlStdOperatorTable.EXTRACT.equals(call.op)) { + return Optional.empty(); + } + + if (call.operandCount() < 2) { + return Optional.empty(); + } + + RexNode extractType = call.operands.get(0); + if (extractType.getType().getSqlTypeName() != SqlTypeName.SYMBOL) { + return Optional.empty(); + } + + final RexNode dataType = call.operands.get(1); + if (!dataType.getType().getSqlTypeName().equals(SqlTypeName.DATE)) { + return Optional.empty(); + } + + TimeUnitRange value = ((RexLiteral) extractType).getValueAs(TimeUnitRange.class); + + switch (value) { + case QUARTER: + case MONTH: + case DAY: + { + final List newOperands = new LinkedList<>(call.operands); + newOperands.add(1, RexBuilder.DEFAULT.makeFlag(ExtractIndexing.ONE)); + + final ScalarFunctionVariant substraitFn = + this.extractFunctions.get("extract:req_req_date"); + return Optional.of( + new SubstraitFunctionMapping("extract", newOperands, List.of(substraitFn))); + } + default: + return Optional.empty(); + } + } + + @Override + public Optional> getExpressionArguments( + final Expression.ScalarFunctionInvocation expression) { + String name = expression.declaration().toString(); + + if ("extract:req_req_date".equals(name)) { + final List newArgs = new LinkedList<>(expression.arguments()); + newArgs.remove(1); + + return Optional.of(newArgs); + } + return Optional.empty(); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractIndexing.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractIndexing.java new file mode 100644 index 000000000..54c9bff41 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExtractIndexing.java @@ -0,0 +1,11 @@ +package io.substrait.isthmus.expression; + +/** + * Enum to define the INDEXING property on the date functions. + * + *

Controls if the number used for example in months is 0 or 1 based. + */ +public enum ExtractIndexing { + ONE, + ZERO +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 02cb8a116..5d91e6a9a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -88,7 +88,7 @@ public Expression.Literal convert(RexLiteral literal) { return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class)); case CHAR: { - Comparable val = literal.getValue(); + Comparable val = literal.getValue(); if (val instanceof NlsString) { NlsString nls = (NlsString) val; return ExpressionCreator.fixedChar(n, nls.getValue()); @@ -127,11 +127,11 @@ public Expression.Literal convert(RexLiteral literal) { case SYMBOL: { Object value = literal.getValue(); - // case TimeUnitRange tur -> string(n, tur.name()); if (value instanceof NlsString) { return ExpressionCreator.string(n, ((NlsString) value).getValue()); } else if (value instanceof Enum) { Enum v = (Enum) value; + Optional r = EnumConverter.canConvert(v) ? Optional.of(ExpressionCreator.string(n, v.name())) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java index b3ad6514c..00bfeec4b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java @@ -24,7 +24,6 @@ public class ScalarFunctionConverter Expression, ScalarFunctionConverter.WrappedScalarCall> implements CallConverter { - /** * Function mappers provide a hook point for any custom mapping to Substrait functions and * arguments. @@ -43,7 +42,11 @@ public ScalarFunctionConverter( TypeConverter typeConverter) { super(functions, additionalSignatures, typeFactory, typeConverter); - mappers = List.of(new TrimFunctionMapper(functions), new SqrtFunctionMapper(functions)); + mappers = + List.of( + new TrimFunctionMapper(functions), + new SqrtFunctionMapper(functions), + new ExtractDateFunctionMapper(functions)); } @Override @@ -68,6 +71,7 @@ private Optional getMappingForCall(final RexCall call) .orElse(Optional.empty()); } + /** Application of the more complex mappings. */ private Optional mappedConvert( SubstraitFunctionMapping mapping, RexCall call, @@ -85,6 +89,7 @@ public Stream getOperands() { return attemptMatch(finder, wrapped, topLevelConverter); } + /** Default conversion for functions that have simple 1:1 mappings. */ private Optional defaultConvert( RexCall call, Function topLevelConverter) { FunctionFinder finder = signatures.get(call.op); diff --git a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java index 102aa2f75..211f1b1ef 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java @@ -48,9 +48,8 @@ class FunctionConversionTest extends PlanTestBase { @Test void subtractDateIDay() { // When this function is converted to Calcite, if the Calcite type derivation is used an - // java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that this is being - // mapped to the wrong - // Calcite function. + // java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that + // this is being mapped to the wrong Calcite function. // TODO: https://github.com/substrait-io/substrait-java/issues/377 Expression.ScalarFunctionInvocation expr = b.scalarFn( @@ -183,6 +182,25 @@ void extractTimeScalarFunction() { assertEquals("EXTRACT(FLAG(MINUTE), 00:00:00:TIME(6))", extract.toString()); } + @Test + void extractDateWithIndexing() { + ScalarFunctionInvocation reqReqDateFn = + b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_DATETIME, + "extract:req_req_date", + TypeCreator.REQUIRED.I64, + EnumArg.builder().value("MONTH").build(), + EnumArg.builder().value("ONE").build(), + Expression.DateLiteral.builder().value(0).build()); + + RexNode calciteExpr = reqReqDateFn.accept(expressionRexConverter, Context.newContext()); + assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); + assertInstanceOf(RexCall.class, calciteExpr); + + RexCall extract = (RexCall) calciteExpr; + assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01)", extract.toString()); + } + @Test void unsupportedExtractTimestampTzWithIndexing() { ScalarFunctionInvocation reqReqTstzFn = @@ -249,22 +267,6 @@ void unsupportedExtractPrecisionTimestampWithIndexing() { () -> reqReqPtsFn.accept(expressionRexConverter, Context.newContext())); } - @Test - void unsupportedExtractDateWithIndexing() { - ScalarFunctionInvocation reqReqDateFn = - b.scalarFn( - DefaultExtensionCatalog.FUNCTIONS_DATETIME, - "extract:req_req_date", - TypeCreator.REQUIRED.I64, - EnumArg.builder().value("MONTH").build(), - EnumArg.builder().value("ONE").build(), - Expression.DateLiteral.builder().value(0).build()); - - assertThrows( - UnsupportedOperationException.class, - () -> reqReqDateFn.accept(expressionRexConverter, Context.newContext())); - } - @Test void concatStringLiteralAndVarchar() throws Exception { assertProtoPlanRoundrip("select 'part_'||P_NAME from PART"); diff --git a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java index 7060e9075..e8a5e7c23 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java @@ -212,4 +212,11 @@ void caseWhenTest() throws Exception { assertFullRoundTrip( "select case when p_size > 100 then 'large' when p_size > 50 then 'medium' else 'small' end from part"); } + + @Test + void dateFunctions() throws Exception { + assertSqlSubstraitRelRoundTrip("select month(o_orderdate),year(o_orderdate) from orders"); + assertSqlSubstraitRelRoundTrip( + "select extract(month from o_orderdate),extract(year from o_orderdate),extract(day from o_orderdate) from orders"); + } }