diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java index ada1808351f8a4..8ab8dba509958e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java @@ -20,7 +20,6 @@ import org.apache.doris.catalog.AliasFunction; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionSignature; -import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; @@ -42,23 +41,15 @@ * alias function */ public class AliasUdf extends ScalarFunction implements ExplicitlyCastableSignature { - private final UnboundFunction unboundFunction; + private final Expression unboundFunction; private final List parameters; private final List argTypes; private final Map sessionVariables; - /** - * constructor - */ - public AliasUdf(String name, List argTypes, UnboundFunction unboundFunction, - List parameters, Expression... arguments) { - this(name, argTypes, unboundFunction, parameters, null, arguments); - } - /** * constructor with session variables */ - public AliasUdf(String name, List argTypes, UnboundFunction unboundFunction, + public AliasUdf(String name, List argTypes, Expression unboundFunction, List parameters, Map sessionVariables, Expression... arguments) { super(name, arguments); this.argTypes = argTypes; @@ -76,7 +67,7 @@ public List getParameters() { return parameters; } - public UnboundFunction getUnboundFunction() { + public Expression getUnboundFunction() { return unboundFunction; } @@ -107,7 +98,7 @@ public static void translateToNereidsFunction(String dbName, AliasFunction funct AliasUdf aliasUdf = new AliasUdf( function.functionName(), Arrays.stream(function.getArgs()).map(DataType::fromCatalogType).collect(Collectors.toList()), - ((UnboundFunction) parsedFunction), + parsedFunction, function.getParameters(), sessionVariables); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index cec745e1cd4213..2ba2ed28d95996 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -67,10 +67,15 @@ import org.apache.doris.nereids.trees.expressions.IntegralDivide; import org.apache.doris.nereids.trees.expressions.Mod; import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.Placeholder; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; +import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgTypesInfo; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; @@ -1025,7 +1030,7 @@ private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) th ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>()); CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan, PhysicalProperties.ANY); - Map argTypeMap = new CaseInsensitiveMap(); + Map argTypeMap = new CaseInsensitiveMap<>(); List argTypes = argsDef.getArgTypeDefs(); if (!parameters.isEmpty()) { if (parameters.size() != argTypes.size()) { @@ -1039,6 +1044,13 @@ private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) th ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(cascadesContext, argTypeMap); expression = analyzer.analyze(expression); + if (expression.containsType( + org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class, + GroupingScalarFunction.class, WindowExpression.class, Placeholder.class, + TableValuedFunction.class, SubqueryExpr.class)) { + throw new AnalysisException("Alias function only supports scalar functions."); + } + PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext); ExpressionToExpr translator = new ExpressionToExpr(); return expression.accept(translator, translatorContext); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java index d2e004d0a54113..73a44299f7375e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java @@ -188,6 +188,37 @@ public void testParameterUseMoreThanOneTime() throws Exception { ); } + @Test + public void testAliasFunctionWithCastOutermostExpression() throws Exception { + // Bug fix: when the outermost expression of an alias function body is a Cast, + // AliasUdf.translateToNereidsFunction previously cast parsedFunction to UnboundFunction, + // causing ClassCastException. After the fix, parsedFunction is kept as Expression. + createFunction( + "create alias function f_cast_varchar(int) with parameter(n) as cast(n as varchar(20))"); + + Assertions.assertEquals(1, Env.getCurrentEnv().getFunctionRegistry() + .findUdfBuilder(connectContext.getDatabase(), "f_cast_varchar").size()); + + // Verify the function can be used in a query without error + PlanChecker.from(connectContext) + .analyze("select f_cast_varchar(42)") + .matches( + logicalOneRowRelation() + .when(oneRow -> oneRow.getProjects().size() == 1) + ); + } + + @Test + public void testAliasFunctionWithIllegalExpressionsRejected() throws Exception { + // Bug fix: before the fix, alias functions containing aggregate functions in their body + // could be created successfully, which is incorrect behavior. + // After the fix, they are rejected with a clear error message. + Exception e = Assertions.assertThrows(Exception.class, () -> + createFunction( + "create alias function f_agg_rejected(int) with parameter(n) as sum(n)")); + Assertions.assertTrue(e.getMessage().contains("Alias function only supports scalar functions.")); + } + @Test public void testReadFromStream() throws Exception { createFunction("create global alias function f8(int) with parameter(n) as hours_add(now(3), n)");