Skip to content

Commit 6f39b3a

Browse files
committed
Fix const fold
1 parent cf7db21 commit 6f39b3a

3 files changed

Lines changed: 75 additions & 25 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import static com.google.common.base.Preconditions.checkNotNull;
1717
import static com.google.common.collect.ImmutableList.toImmutableList;
18-
import static com.google.common.collect.MoreCollectors.onlyElement;
1918
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION;
2019
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP;
2120

@@ -183,9 +182,9 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
183182
if (functionName.equals(Operator.EQUALS.getFunction())
184183
|| functionName.equals(Operator.NOT_EQUALS.getFunction())) {
185184
if (mutableCall.args().stream()
186-
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
185+
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
187186
|| mutableCall.args().stream()
188-
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
187+
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
189188
return true;
190189
}
191190
}
@@ -196,17 +195,69 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
196195

197196
// Default case: all call arguments must be constants. If the argument is a container (ex:
198197
// list, map), then its arguments must be a constant.
199-
return areChildrenArgConstant(navigableExpr);
198+
return navigableExpr.children().allMatch(this::canEvaluate);
200199
case SELECT:
201-
CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement());
202-
return areChildrenArgConstant(operand);
200+
return navigableExpr.children().allMatch(this::canEvaluate);
203201
case COMPREHENSION:
204-
return !isNestedComprehension(navigableExpr);
202+
if (isNestedComprehension(navigableExpr)) {
203+
return false;
204+
}
205+
CelMutableComprehension comprehension = navigableExpr.expr().comprehension();
206+
207+
if (!canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.iterRange()))
208+
|| !canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.accuInit()))) {
209+
return false;
210+
}
211+
212+
return canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopStep()))
213+
&& canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopCondition()));
205214
default:
206215
return false;
207216
}
208217
}
209218

219+
/**
220+
* Checks if a subtree is safe to evaluate (i.e: it evaluates down to a constant expression)
221+
*/
222+
private boolean canEvaluate(CelNavigableMutableExpr expression) {
223+
return expression.allNodes().allMatch(this::isAllowedInConstantExpr);
224+
}
225+
226+
private boolean canEvaluateComprehensionBody(CelNavigableMutableExpr expression) {
227+
return expression.allNodes().allMatch(node -> {
228+
Kind kind = node.getKind();
229+
if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) {
230+
return true;
231+
}
232+
return isAllowedInConstantExpr(node);
233+
});
234+
}
235+
236+
private boolean isAllowedInConstantExpr(CelNavigableMutableExpr node) {
237+
Kind kind = node.getKind();
238+
if (kind.equals(Kind.CONSTANT)
239+
|| kind.equals(Kind.LIST)
240+
|| kind.equals(Kind.MAP)
241+
|| kind.equals(Kind.STRUCT)
242+
|| kind.equals(Kind.SELECT)) {
243+
return true;
244+
}
245+
if (kind.equals(Kind.CALL)) {
246+
CelMutableCall call = node.expr().call();
247+
return foldableFunctions.contains(call.function());
248+
}
249+
250+
return false;
251+
}
252+
253+
private boolean isAllowedInFoldableExpr(CelNavigableMutableExpr node) {
254+
Kind kind = node.getKind();
255+
if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) {
256+
return true;
257+
}
258+
return isAllowedInConstantExpr(node);
259+
}
260+
210261
private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) {
211262
return navigableExpr
212263
.allNodes()
@@ -248,22 +299,6 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
248299
return true;
249300
}
250301

251-
private static boolean areChildrenArgConstant(CelNavigableMutableExpr expr) {
252-
if (expr.getKind().equals(Kind.CONSTANT)) {
253-
return true;
254-
}
255-
256-
if (expr.getKind().equals(Kind.CALL)
257-
|| expr.getKind().equals(Kind.LIST)
258-
|| expr.getKind().equals(Kind.MAP)
259-
|| expr.getKind().equals(Kind.SELECT)
260-
|| expr.getKind().equals(Kind.STRUCT)) {
261-
return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant);
262-
}
263-
264-
return false;
265-
}
266-
267302
private static boolean isNestedComprehension(CelNavigableMutableExpr expr) {
268303
Optional<CelNavigableMutableExpr> maybeParent = expr.parent();
269304
while (maybeParent.isPresent()) {

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
public class ConstantFoldingOptimizerTest {
4949
private static final CelOptions CEL_OPTIONS =
5050
CelOptions.current()
51+
.populateMacroCalls(true)
5152
.enableTimestampEpoch(true)
5253
.build();
5354
private static final Cel CEL =
@@ -56,12 +57,23 @@ public class ConstantFoldingOptimizerTest {
5657
.addVar("y", SimpleType.DYN)
5758
.addVar("list_var", ListType.create(SimpleType.STRING))
5859
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING))
60+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
5961
.addFunctionDeclarations(
6062
CelFunctionDecl.newFunctionDeclaration(
6163
"get_true",
62-
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
64+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)),
65+
CelFunctionDecl.newFunctionDeclaration(
66+
"get_list",
67+
CelOverloadDecl.newGlobalOverload(
68+
"get_list_overload",
69+
ListType.create(SimpleType.INT),
70+
ListType.create(SimpleType.INT)))
71+
)
6372
.addFunctionBindings(
64-
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true))
73+
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true),
74+
CelFunctionBinding.from(
75+
"get_list_overload", ImmutableList.class, arg -> arg)
76+
)
6577
.addMessageTypes(TestAllTypes.getDescriptor())
6678
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
6779
.setOptions(CEL_OPTIONS)
@@ -371,6 +383,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
371383
@TestParameters("{source: 'x == 42'}")
372384
@TestParameters("{source: 'timestamp(100)'}")
373385
@TestParameters("{source: 'duration(\"1h\")'}")
386+
@TestParameters("{source: '[true].exists(x, x == get_true())'}")
387+
@TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}")
374388
public void constantFold_noOp(String source) throws Exception {
375389
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
376390

tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ private void runTests(Cel cel, CelAbstractSyntaxTree ast, PolicyTestSuite testSu
283283
? (List<AgentMessage>) inputMap.get("_test_history")
284284
: ImmutableList.of();
285285

286+
@SuppressWarnings("Immutable")
286287
CelLateFunctionBindings bindings = CelLateFunctionBindings.from(
287288
CelFunctionBinding.from(
288289
"agent_history",

0 commit comments

Comments
 (0)