Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.apache.doris.catalog.AliasFunction;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
Expand All @@ -42,23 +41,15 @@
* alias function
*/
public class AliasUdf extends ScalarFunction implements ExplicitlyCastableSignature {
private final UnboundFunction unboundFunction;
private final Expression unboundFunction;
private final List<String> parameters;
private final List<DataType> argTypes;
private final Map<String, String> sessionVariables;

/**
* constructor
*/
public AliasUdf(String name, List<DataType> argTypes, UnboundFunction unboundFunction,
List<String> parameters, Expression... arguments) {
this(name, argTypes, unboundFunction, parameters, null, arguments);
}

/**
* constructor with session variables
*/
public AliasUdf(String name, List<DataType> argTypes, UnboundFunction unboundFunction,
public AliasUdf(String name, List<DataType> argTypes, Expression unboundFunction,
List<String> parameters, Map<String, String> sessionVariables, Expression... arguments) {
super(name, arguments);
this.argTypes = argTypes;
Expand All @@ -76,7 +67,7 @@ public List<String> getParameters() {
return parameters;
}

public UnboundFunction getUnboundFunction() {
public Expression getUnboundFunction() {
return unboundFunction;
}

Expand Down Expand Up @@ -107,7 +98,7 @@ public static void translateToNereidsFunction(String dbName, AliasFunction funct
AliasUdf aliasUdf = new AliasUdf(
function.functionName(),
Arrays.stream(function.getArgs()).map(DataType::fromCatalogType).collect(Collectors.toList()),
((UnboundFunction) parsedFunction),
parsedFunction,
function.getParameters(),
sessionVariables);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgTypesInfo;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
Expand Down Expand Up @@ -1025,7 +1030,7 @@ private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) th
ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>());
CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan,
PhysicalProperties.ANY);
Map<String, DataType> argTypeMap = new CaseInsensitiveMap();
Map<String, DataType> argTypeMap = new CaseInsensitiveMap<>();
List<DataType> argTypes = argsDef.getArgTypeDefs();
if (!parameters.isEmpty()) {
if (parameters.size() != argTypes.size()) {
Expand All @@ -1039,6 +1044,13 @@ private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) th
ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(cascadesContext, argTypeMap);
expression = analyzer.analyze(expression);

if (expression.containsType(
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class,
GroupingScalarFunction.class, WindowExpression.class, Placeholder.class,
TableValuedFunction.class, SubqueryExpr.class)) {
throw new AnalysisException("Alias function only supports scalar functions.");
}

PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext);
ExpressionToExpr translator = new ExpressionToExpr();
return expression.accept(translator, translatorContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ public void testParameterUseMoreThanOneTime() throws Exception {
);
}

@Test
public void testAliasFunctionWithCastOutermostExpression() throws Exception {
// Bug fix: when the outermost expression of an alias function body is a Cast,
// AliasUdf.translateToNereidsFunction previously cast parsedFunction to UnboundFunction,
// causing ClassCastException. After the fix, parsedFunction is kept as Expression.
createFunction(
"create alias function f_cast_varchar(int) with parameter(n) as cast(n as varchar(20))");

Assertions.assertEquals(1, Env.getCurrentEnv().getFunctionRegistry()
.findUdfBuilder(connectContext.getDatabase(), "f_cast_varchar").size());

// Verify the function can be used in a query without error
PlanChecker.from(connectContext)
.analyze("select f_cast_varchar(42)")
.matches(
logicalOneRowRelation()
.when(oneRow -> oneRow.getProjects().size() == 1)
);
}

@Test
public void testAliasFunctionWithIllegalExpressionsRejected() throws Exception {
// Bug fix: before the fix, alias functions containing aggregate functions in their body
// could be created successfully, which is incorrect behavior.
// After the fix, they are rejected with a clear error message.
Exception e = Assertions.assertThrows(Exception.class, () ->
createFunction(
"create alias function f_agg_rejected(int) with parameter(n) as sum(n)"));
Assertions.assertTrue(e.getMessage().contains("Alias function only supports scalar functions."));
}

@Test
public void testReadFromStream() throws Exception {
createFunction("create global alias function f8(int) with parameter(n) as hours_add(now(3), n)");
Expand Down
Loading