Skip to content

Commit 82141bb

Browse files
authored
feat(isthmus): support month()-style datetime operators (#643)
The Substrait function for month-style mappings doesn't work, because the Calcite function has only 2 arguments whereas the Substrait function has 3. Isthmus (in either direction) assumes that there is a 1:1 mapping between a Substrait function and a Calcite function, with the same number of arguments in each (although types can be coerced if needed). In this case, the 3-argument Substrait function needs to be mapped to a 2-argument Calcite one. In effect, this works by swapping out the function for a different one and then letting the rest of the logic work as before. A concern is the increased complexity; the code as a whole wasn't built with this in mind. Signed-off-by: MBWhite <whitemat@uk.ibm.com>
1 parent 9060539 commit 82141bb

File tree

9 files changed

+159
-24
lines changed

9 files changed

+159
-24
lines changed

core/src/main/java/io/substrait/expression/EnumArg.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ default <R, C extends VisitationContext, E extends Throwable> R accept(
2626
}
2727

2828
static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
29-
assert (enumArg.options().contains(option));
29+
if (!enumArg.options().contains(option)) {
30+
throw new IllegalArgumentException(
31+
String.format("EnumArg value %s not valid for options: %s", option, enumArg.options()));
32+
}
3033
return builder().value(Optional.of(option)).build();
3134
}
3235

examples/isthmus-api/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ _apps
22
_data
33
**/*/bin
44
build
5+
substrait.plan

isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ public class EnumConverter {
7272
calciteEnumMap.put(
7373
argAnchor(DefaultExtensionCatalog.FUNCTIONS_STRING, "rtrim:str_str", 0),
7474
SqlTrimFunction.Flag.class);
75+
calciteEnumMap.put(
76+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 0),
77+
TimeUnitRange.class);
78+
calciteEnumMap.put(
79+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1),
80+
ExtractIndexing.class);
7581
}
7682

7783
private static Optional<Enum<?>> constructValue(
@@ -84,6 +90,10 @@ private static Optional<Enum<?>> constructValue(
8490
return option.get().map(SqlTrimFunction.Flag::valueOf);
8591
}
8692

93+
// ExtractIndexing does not need to be converted here. Calcite
94+
// doesn't have the concept of the indexing. Its date
95+
// functions are all indexed from 1
96+
8797
return Optional.empty();
8898
}
8999

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package io.substrait.isthmus.expression;
2+
3+
import io.substrait.expression.Expression;
4+
import io.substrait.expression.FunctionArg;
5+
import io.substrait.extension.SimpleExtension.ScalarFunctionVariant;
6+
import java.util.Collections;
7+
import java.util.LinkedList;
8+
import java.util.List;
9+
import java.util.Map;
10+
import java.util.Optional;
11+
import java.util.stream.Collectors;
12+
import org.apache.calcite.avatica.util.TimeUnitRange;
13+
import org.apache.calcite.rex.RexBuilder;
14+
import org.apache.calcite.rex.RexCall;
15+
import org.apache.calcite.rex.RexLiteral;
16+
import org.apache.calcite.rex.RexNode;
17+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
18+
import org.apache.calcite.sql.type.SqlTypeName;
19+
20+
/**
21+
* Custom mapping for the Calcite MONTH/DAY/QUARTER functions.
22+
*
23+
* <p>These come from Calcite as 2 argument functions (like YEAR) but in Substrait these functions
24+
* are 3 arguments; the additional being if this is a 0 or 1 based value. Calcite is a 1 based value
25+
* in this case.
26+
*
27+
* <p>We need to therefore map the MONTH etc functions to a different Substrait function.
28+
*/
29+
final class ExtractDateFunctionMapper implements ScalarFunctionMapper {
30+
31+
private final Map<String, ScalarFunctionVariant> extractFunctions;
32+
33+
public ExtractDateFunctionMapper(List<ScalarFunctionVariant> functions) {
34+
35+
Map<String, ScalarFunctionVariant> fns =
36+
functions.stream()
37+
.filter(f -> "extract".equals(f.name()))
38+
.collect(Collectors.toMap(Object::toString, f -> f));
39+
40+
this.extractFunctions = Collections.unmodifiableMap(fns);
41+
}
42+
43+
@Override
44+
public Optional<SubstraitFunctionMapping> toSubstrait(final RexCall call) {
45+
if (!SqlStdOperatorTable.EXTRACT.equals(call.op)) {
46+
return Optional.empty();
47+
}
48+
49+
if (call.operandCount() < 2) {
50+
return Optional.empty();
51+
}
52+
53+
RexNode extractType = call.operands.get(0);
54+
if (extractType.getType().getSqlTypeName() != SqlTypeName.SYMBOL) {
55+
return Optional.empty();
56+
}
57+
58+
final RexNode dataType = call.operands.get(1);
59+
if (!dataType.getType().getSqlTypeName().equals(SqlTypeName.DATE)) {
60+
return Optional.empty();
61+
}
62+
63+
TimeUnitRange value = ((RexLiteral) extractType).getValueAs(TimeUnitRange.class);
64+
65+
switch (value) {
66+
case QUARTER:
67+
case MONTH:
68+
case DAY:
69+
{
70+
final List<RexNode> newOperands = new LinkedList<>(call.operands);
71+
newOperands.add(1, RexBuilder.DEFAULT.makeFlag(ExtractIndexing.ONE));
72+
73+
final ScalarFunctionVariant substraitFn =
74+
this.extractFunctions.get("extract:req_req_date");
75+
return Optional.of(
76+
new SubstraitFunctionMapping("extract", newOperands, List.of(substraitFn)));
77+
}
78+
default:
79+
return Optional.empty();
80+
}
81+
}
82+
83+
@Override
84+
public Optional<List<FunctionArg>> getExpressionArguments(
85+
final Expression.ScalarFunctionInvocation expression) {
86+
String name = expression.declaration().toString();
87+
88+
if ("extract:req_req_date".equals(name)) {
89+
final List<FunctionArg> newArgs = new LinkedList<>(expression.arguments());
90+
newArgs.remove(1);
91+
92+
return Optional.of(newArgs);
93+
}
94+
return Optional.empty();
95+
}
96+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package io.substrait.isthmus.expression;
2+
3+
/**
4+
* Enum to define the INDEXING property on the date functions.
5+
*
6+
* <p>Controls if the number used for example in months is 0 or 1 based.
7+
*/
8+
public enum ExtractIndexing {
9+
ONE,
10+
ZERO
11+
}

isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public Expression.Literal convert(RexLiteral literal) {
8888
return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class));
8989
case CHAR:
9090
{
91-
Comparable val = literal.getValue();
91+
Comparable<?> val = literal.getValue();
9292
if (val instanceof NlsString) {
9393
NlsString nls = (NlsString) val;
9494
return ExpressionCreator.fixedChar(n, nls.getValue());
@@ -127,11 +127,11 @@ public Expression.Literal convert(RexLiteral literal) {
127127
case SYMBOL:
128128
{
129129
Object value = literal.getValue();
130-
// case TimeUnitRange tur -> string(n, tur.name());
131130
if (value instanceof NlsString) {
132131
return ExpressionCreator.string(n, ((NlsString) value).getValue());
133132
} else if (value instanceof Enum) {
134133
Enum<?> v = (Enum<?>) value;
134+
135135
Optional<Expression.Literal> r =
136136
EnumConverter.canConvert(v)
137137
? Optional.of(ExpressionCreator.string(n, v.name()))

isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ public class ScalarFunctionConverter
2424
Expression,
2525
ScalarFunctionConverter.WrappedScalarCall>
2626
implements CallConverter {
27-
2827
/**
2928
* Function mappers provide a hook point for any custom mapping to Substrait functions and
3029
* arguments.
@@ -43,7 +42,11 @@ public ScalarFunctionConverter(
4342
TypeConverter typeConverter) {
4443
super(functions, additionalSignatures, typeFactory, typeConverter);
4544

46-
mappers = List.of(new TrimFunctionMapper(functions), new SqrtFunctionMapper(functions));
45+
mappers =
46+
List.of(
47+
new TrimFunctionMapper(functions),
48+
new SqrtFunctionMapper(functions),
49+
new ExtractDateFunctionMapper(functions));
4750
}
4851

4952
@Override
@@ -68,6 +71,7 @@ private Optional<SubstraitFunctionMapping> getMappingForCall(final RexCall call)
6871
.orElse(Optional.empty());
6972
}
7073

74+
/** Application of the more complex mappings. */
7175
private Optional<Expression> mappedConvert(
7276
SubstraitFunctionMapping mapping,
7377
RexCall call,
@@ -85,6 +89,7 @@ public Stream<RexNode> getOperands() {
8589
return attemptMatch(finder, wrapped, topLevelConverter);
8690
}
8791

92+
/** Default conversion for functions that have simple 1:1 mappings. */
8893
private Optional<Expression> defaultConvert(
8994
RexCall call, Function<RexNode, Expression> topLevelConverter) {
9095
FunctionFinder finder = signatures.get(call.op);

isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ class FunctionConversionTest extends PlanTestBase {
4848
@Test
4949
void subtractDateIDay() {
5050
// When this function is converted to Calcite, if the Calcite type derivation is used an
51-
// java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that this is being
52-
// mapped to the wrong
53-
// Calcite function.
51+
// java.lang.ArrayIndexOutOfBoundsException is thrown. It is quite likely that
52+
// this is being mapped to the wrong Calcite function.
5453
// TODO: https://github.com/substrait-io/substrait-java/issues/377
5554
Expression.ScalarFunctionInvocation expr =
5655
b.scalarFn(
@@ -183,6 +182,25 @@ void extractTimeScalarFunction() {
183182
assertEquals("EXTRACT(FLAG(MINUTE), 00:00:00:TIME(6))", extract.toString());
184183
}
185184

185+
@Test
186+
void extractDateWithIndexing() {
187+
ScalarFunctionInvocation reqReqDateFn =
188+
b.scalarFn(
189+
DefaultExtensionCatalog.FUNCTIONS_DATETIME,
190+
"extract:req_req_date",
191+
TypeCreator.REQUIRED.I64,
192+
EnumArg.builder().value("MONTH").build(),
193+
EnumArg.builder().value("ONE").build(),
194+
Expression.DateLiteral.builder().value(0).build());
195+
196+
RexNode calciteExpr = reqReqDateFn.accept(expressionRexConverter, Context.newContext());
197+
assertEquals(SqlKind.EXTRACT, calciteExpr.getKind());
198+
assertInstanceOf(RexCall.class, calciteExpr);
199+
200+
RexCall extract = (RexCall) calciteExpr;
201+
assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01)", extract.toString());
202+
}
203+
186204
@Test
187205
void unsupportedExtractTimestampTzWithIndexing() {
188206
ScalarFunctionInvocation reqReqTstzFn =
@@ -249,22 +267,6 @@ void unsupportedExtractPrecisionTimestampWithIndexing() {
249267
() -> reqReqPtsFn.accept(expressionRexConverter, Context.newContext()));
250268
}
251269

252-
@Test
253-
void unsupportedExtractDateWithIndexing() {
254-
ScalarFunctionInvocation reqReqDateFn =
255-
b.scalarFn(
256-
DefaultExtensionCatalog.FUNCTIONS_DATETIME,
257-
"extract:req_req_date",
258-
TypeCreator.REQUIRED.I64,
259-
EnumArg.builder().value("MONTH").build(),
260-
EnumArg.builder().value("ONE").build(),
261-
Expression.DateLiteral.builder().value(0).build());
262-
263-
assertThrows(
264-
UnsupportedOperationException.class,
265-
() -> reqReqDateFn.accept(expressionRexConverter, Context.newContext()));
266-
}
267-
268270
@Test
269271
void concatStringLiteralAndVarchar() throws Exception {
270272
assertProtoPlanRoundrip("select 'part_'||P_NAME from PART");

isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,11 @@ void caseWhenTest() throws Exception {
212212
assertFullRoundTrip(
213213
"select case when p_size > 100 then 'large' when p_size > 50 then 'medium' else 'small' end from part");
214214
}
215+
216+
@Test
217+
void dateFunctions() throws Exception {
218+
assertSqlSubstraitRelRoundTrip("select month(o_orderdate),year(o_orderdate) from orders");
219+
assertSqlSubstraitRelRoundTrip(
220+
"select extract(month from o_orderdate),extract(year from o_orderdate),extract(day from o_orderdate) from orders");
221+
}
215222
}

0 commit comments

Comments
 (0)