diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 1281f3bad2..4bd8916448 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -589,7 +589,7 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the | `to_char` | ✅ | | | `to_number` | ✅ | | | `to_varchar` | ✅ | | -| `translate` | ✅ | | +| `translate` | ✅ | Falls back by default; opt-in via allowIncompatible ([#4463](https://github.com/apache/datafusion-comet/issues/4463)) | | `trim` | ✅ | | | `try_to_binary` | 🔜 | Lowers to `TryEval(...)`, which falls back | | `try_to_number` | 🔜 | TRY variant of `to_number` | diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a21d930226..7ee74bb3ab 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -208,7 +208,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { // when `++` is applied directly to a `Map(...)` literal. val base: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Ascii] -> CometScalarFunction("ascii"), - classOf[BitLength] -> CometScalarFunction("bit_length"), + classOf[BitLength] -> CometBitLength, classOf[Chr] -> CometScalarFunction("char"), classOf[ConcatWs] -> CometConcatWs, classOf[Concat] -> CometConcat, @@ -220,7 +220,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Levenshtein] -> CometLevenshtein, classOf[Like] -> CometLike, classOf[Lower] -> CometLower, - classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[OctetLength] -> CometOctetLength, classOf[RegExpExtract] -> CometRegExpExtract, classOf[RegExpExtractAll] -> CometRegExpExtractAll, classOf[RegExpInStr] -> CometRegExpInStr, @@ -235,7 +235,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[StringLPad] -> CometStringLPad, classOf[StringSpace] -> CometScalarFunction("space"), classOf[StringSplit] -> CometStringSplit, - classOf[StringTranslate] -> CometScalarFunction("translate"), + classOf[StringTranslate] -> CometStringTranslate, classOf[StringTrim] -> CometScalarFunction("trim"), classOf[StringTrimLeft] -> CometScalarFunction("ltrim"), classOf[StringTrimRight] -> CometScalarFunction("rtrim"), diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index f2d4da1b64..c4abe8ad4e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Elt, Expression, FindInSet, FormatNumber, FormatString, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, Overlay, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, SoundEx, StringLocate, StringLPad, StringRepeat, StringReplace, StringRPad, StringSplit, Substring, SubstringIndex, ToCharacter, ToNumber, UnBase64, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BitLength, Cast, Concat, ConcatWs, Elt, Expression, FindInSet, FormatNumber, FormatString, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, OctetLength, Overlay, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, SoundEx, StringLocate, StringLPad, StringRepeat, StringReplace, StringRPad, StringSplit, StringTranslate, Substring, SubstringIndex, ToCharacter, ToNumber, UnBase64, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -84,6 +84,35 @@ object CometLength extends CometScalarFunction[Length]("length") { } } +object CometBitLength extends CometScalarFunction[BitLength]("bit_length") { + override def getUnsupportedReasons(): Seq[String] = Seq("`BinaryType` input is not supported") + + override def getSupportLevel(expr: BitLength): SupportLevel = expr.child.dataType match { + case _: BinaryType => Unsupported(Some("bit_length on BinaryType is not supported")) + case _ => Compatible() + } +} + +object CometOctetLength extends CometScalarFunction[OctetLength]("octet_length") { + override def getUnsupportedReasons(): Seq[String] = Seq("`BinaryType` input is not supported") + + override def getSupportLevel(expr: OctetLength): SupportLevel = expr.child.dataType match { + case _: BinaryType => Unsupported(Some("octet_length on BinaryType is not supported")) + case _ => Compatible() + } +} + +object CometStringTranslate extends CometScalarFunction[StringTranslate]("translate") { + private val incompatReason = + "DataFusion's translate iterates over Unicode graphemes (Spark uses code points) and" + + " substitutes U+0000 instead of treating it as a deletion sentinel" + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getSupportLevel(expr: StringTranslate): SupportLevel = Incompatible( + Some(incompatReason)) +} + object CometInitCap extends CometScalarFunction[InitCap]("initcap") { override def getSupportLevel(expr: InitCap): SupportLevel = Compatible() diff --git a/spark/src/test/resources/sql-tests/expressions/string/bit_length.sql b/spark/src/test/resources/sql-tests/expressions/string/bit_length.sql index 34565d6fe3..a232712016 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/bit_length.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/bit_length.sql @@ -27,3 +27,17 @@ SELECT bit_length(s) FROM test_bit_length -- literal arguments query SELECT bit_length('hello'), bit_length(''), bit_length(NULL) + +-- BinaryType input falls back to Spark; the native DataFusion impl rejects Binary at runtime, +-- so the serde gates Binary as Unsupported (matching the existing CometLength shape). +statement +CREATE TABLE test_bit_length_binary(b binary) USING parquet + +statement +INSERT INTO test_bit_length_binary VALUES (X'48656c6c6f'), (X''), (NULL), (X'FF') + +query expect_fallback(bit_length on BinaryType is not supported) +SELECT bit_length(b) FROM test_bit_length_binary + +query expect_fallback(bit_length on BinaryType is not supported) +SELECT bit_length(X'48656c6c6f'), bit_length(CAST(NULL AS BINARY)) diff --git a/spark/src/test/resources/sql-tests/expressions/string/octet_length.sql b/spark/src/test/resources/sql-tests/expressions/string/octet_length.sql index 095d3e30a2..e650950347 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/octet_length.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/octet_length.sql @@ -27,3 +27,17 @@ SELECT octet_length(s) FROM test_octet_length -- literal arguments query SELECT octet_length('hello'), octet_length(''), octet_length(NULL) + +-- BinaryType input falls back to Spark; the native DataFusion impl rejects Binary at runtime, +-- so the serde gates Binary as Unsupported (matching the existing CometLength shape). +statement +CREATE TABLE test_octet_length_binary(b binary) USING parquet + +statement +INSERT INTO test_octet_length_binary VALUES (X'48656c6c6f'), (X''), (NULL), (X'FF') + +query expect_fallback(octet_length on BinaryType is not supported) +SELECT octet_length(b) FROM test_octet_length_binary + +query expect_fallback(octet_length on BinaryType is not supported) +SELECT octet_length(X'48656c6c6f'), octet_length(CAST(NULL AS BINARY)) diff --git a/spark/src/test/resources/sql-tests/expressions/string/string_translate.sql b/spark/src/test/resources/sql-tests/expressions/string/string_translate.sql index 88bd1aa935..d9dabde9f5 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/string_translate.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/string_translate.sql @@ -15,23 +15,29 @@ -- specific language governing permissions and limitations -- under the License. +-- translate is gated as Incompatible by default. DataFusion's translate iterates over Unicode +-- graphemes (Spark uses code points) and substitutes U+0000 instead of treating it as a deletion +-- sentinel, so the native path silently diverges from Spark for combining-mark inputs and for +-- to=NUL. These default-config tests assert that the expression falls back cleanly to Spark. +-- See string_translate_enabled.sql for the opt-in native path. + statement CREATE TABLE test_translate(s string, from_str string, to_str string) USING parquet statement INSERT INTO test_translate VALUES ('hello', 'el', 'ip'), ('hello', 'aeiou', '12345'), ('', 'a', 'b'), (NULL, 'a', 'b'), ('hello', '', ''), ('abc', 'abc', 'x') -query +query expect_fallback(is not fully compatible with Spark) SELECT translate(s, from_str, to_str) FROM test_translate -- column + literal + literal -query +query expect_fallback(is not fully compatible with Spark) SELECT translate(s, 'el', 'ip') FROM test_translate -- literal + column + column -query +query expect_fallback(is not fully compatible with Spark) SELECT translate('hello', from_str, to_str) FROM test_translate -- literal + literal + literal -query +query expect_fallback(is not fully compatible with Spark) SELECT translate('hello', 'el', 'ip'), translate('hello', 'aeiou', '12345'), translate('', 'a', 'b'), translate(NULL, 'a', 'b') diff --git a/spark/src/test/resources/sql-tests/expressions/string/string_translate_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/string_translate_enabled.sql new file mode 100644 index 0000000000..9249730c70 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/string_translate_enabled.sql @@ -0,0 +1,43 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Tests for the native translate path, which the user must opt into via +-- spark.comet.expression.StringTranslate.allowIncompatible=true (gated as Incompatible because +-- DataFusion's translate diverges from Spark on combining-mark inputs and on to=NUL deletion). +-- These ASCII-only tests run under inputs where the two implementations agree. +-- Config: spark.comet.expression.StringTranslate.allowIncompatible=true + +statement +CREATE TABLE test_translate_enabled(s string, from_str string, to_str string) USING parquet + +statement +INSERT INTO test_translate_enabled VALUES ('hello', 'el', 'ip'), ('hello', 'aeiou', '12345'), ('', 'a', 'b'), (NULL, 'a', 'b'), ('hello', '', ''), ('abc', 'abc', 'x') + +query +SELECT translate(s, from_str, to_str) FROM test_translate_enabled + +-- column + literal + literal +query +SELECT translate(s, 'el', 'ip') FROM test_translate_enabled + +-- literal + column + column +query +SELECT translate('hello', from_str, to_str) FROM test_translate_enabled + +-- literal + literal + literal +query +SELECT translate('hello', 'el', 'ip'), translate('hello', 'aeiou', '12345'), translate('', 'a', 'b'), translate(NULL, 'a', 'b') diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index d07257442d..0b5acc70a2 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -373,12 +373,15 @@ class CometStringExpressionSuite extends CometTestBase { test("length, reverse, instr, replace, translate") { val table = "test" - withTable(table) { - sql(s"create table $table(col string) using parquet") - sql( - s"insert into $table values('Spark SQL '), (NULL), (''), ('苹果手机'), ('Spark SQL '), (NULL), (''), ('苹果手机')") - checkSparkAnswerAndOperator("select length(col), reverse(col), instr(col, 'SQL'), instr(col, '手机'), replace(col, 'SQL', '123')," + - s" replace(col, 'SQL'), replace(col, '手机', '平板'), translate(col, 'SL苹', '123') from $table") + withSQLConf("spark.comet.expression.StringTranslate.allowIncompatible" -> "true") { + withTable(table) { + sql(s"create table $table(col string) using parquet") + sql( + s"insert into $table values('Spark SQL '), (NULL), (''), ('苹果手机'), ('Spark SQL '), (NULL), (''), ('苹果手机')") + checkSparkAnswerAndOperator( + "select length(col), reverse(col), instr(col, 'SQL'), instr(col, '手机'), replace(col, 'SQL', '123')," + + s" replace(col, 'SQL'), replace(col, '手机', '平板'), translate(col, 'SL苹', '123') from $table") + } } }