diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 5acf1fe43..750ee9b86 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -78,7 +78,12 @@ public class FunctionMappings { s(SqlStdOperatorTable.BITXOR, "bitwise_xor"), s(SqlStdOperatorTable.RADIANS, "radians"), s(SqlStdOperatorTable.DEGREES, "degrees"), - s(SqlLibraryOperators.FACTORIAL, "factorial")) + s(SqlLibraryOperators.FACTORIAL, "factorial"), + s(SqlStdOperatorTable.IS_TRUE, "is_true"), + s(SqlStdOperatorTable.IS_FALSE, "is_false"), + s(SqlStdOperatorTable.IS_NOT_TRUE, "is_not_true"), + s(SqlStdOperatorTable.IS_NOT_FALSE, "is_not_false"), + s(SqlStdOperatorTable.IS_DISTINCT_FROM, "is_distinct_from")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java new file mode 100644 index 000000000..bcc74db88 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java @@ -0,0 +1,49 @@ +package io.substrait.isthmus; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; + +public class ComparisonFunctionsTest extends PlanTestBase { + static String CREATES = + "CREATE TABLE numbers (int_a INT, int_b INT, double_a DOUBLE, double_b DOUBLE)"; + + @Test + void is_true() throws Exception { + String query = "SELECT ((int_a > int_b) IS TRUE) FROM numbers"; + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @Test + void is_false() throws Exception { + String query = "SELECT ((int_a > int_b) IS FALSE) FROM numbers"; + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @Test + void is_not_true() throws Exception { + String query = "SELECT ((int_a > int_b) IS NOT TRUE) FROM numbers"; + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @Test + void is_not_false() throws Exception { + String query = "SELECT ((int_a > int_b) IS NOT FALSE) FROM numbers"; + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @CsvSource({"int_a, int_b", "int_b, int_a", "double_a, double_b", "double_b, double_a"}) + void is_distinct_from(String left, String right) throws Exception { + String query = String.format("SELECT (%s IS DISTINCT FROM %s) FROM numbers", left, right); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"int_a", "int_b", "double_a", "double_b"}) + void is_distinct_from_null_vs_col(String column) throws Exception { + String query = String.format("SELECT (NULL IS DISTINCT FROM %s) FROM numbers", column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } +}