1515
1616import static com .google .common .base .Preconditions .checkNotNull ;
1717import static com .google .common .collect .ImmutableList .toImmutableList ;
18- import static com .google .common .collect .MoreCollectors .onlyElement ;
1918import static dev .cel .checker .CelStandardDeclarations .StandardFunction .DURATION ;
2019import static dev .cel .checker .CelStandardDeclarations .StandardFunction .TIMESTAMP ;
2120
@@ -183,9 +182,9 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
183182 if (functionName .equals (Operator .EQUALS .getFunction ())
184183 || functionName .equals (Operator .NOT_EQUALS .getFunction ())) {
185184 if (mutableCall .args ().stream ()
186- .anyMatch (node -> isExprConstantOfKind (node , CelConstant .Kind .BOOLEAN_VALUE ))
185+ .anyMatch (node -> isExprConstantOfKind (node , CelConstant .Kind .BOOLEAN_VALUE ))
187186 || mutableCall .args ().stream ()
188- .allMatch (node -> node .getKind ().equals (Kind .CONSTANT ))) {
187+ .allMatch (node -> node .getKind ().equals (Kind .CONSTANT ))) {
189188 return true ;
190189 }
191190 }
@@ -196,17 +195,69 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
196195
197196 // Default case: all call arguments must be constants. If the argument is a container (ex:
198197 // list, map), then its arguments must be a constant.
199- return areChildrenArgConstant ( navigableExpr );
198+ return navigableExpr . children (). allMatch ( this :: canEvaluate );
200199 case SELECT :
201- CelNavigableMutableExpr operand = navigableExpr .children ().collect (onlyElement ());
202- return areChildrenArgConstant (operand );
200+ return navigableExpr .children ().allMatch (this ::canEvaluate );
203201 case COMPREHENSION :
204- return !isNestedComprehension (navigableExpr );
202+ if (isNestedComprehension (navigableExpr )) {
203+ return false ;
204+ }
205+ CelMutableComprehension comprehension = navigableExpr .expr ().comprehension ();
206+
207+ if (!canEvaluate (CelNavigableMutableExpr .fromExpr (comprehension .iterRange ()))
208+ || !canEvaluate (CelNavigableMutableExpr .fromExpr (comprehension .accuInit ()))) {
209+ return false ;
210+ }
211+
212+ return canEvaluateComprehensionBody (CelNavigableMutableExpr .fromExpr (comprehension .loopStep ()))
213+ && canEvaluateComprehensionBody (CelNavigableMutableExpr .fromExpr (comprehension .loopCondition ()));
205214 default :
206215 return false ;
207216 }
208217 }
209218
219+ /**
220+ * Checks if a subtree is safe to evaluate (i.e: it evaluates down to a constant expression)
221+ */
222+ private boolean canEvaluate (CelNavigableMutableExpr expression ) {
223+ return expression .allNodes ().allMatch (this ::isAllowedInConstantExpr );
224+ }
225+
226+ private boolean canEvaluateComprehensionBody (CelNavigableMutableExpr expression ) {
227+ return expression .allNodes ().allMatch (node -> {
228+ Kind kind = node .getKind ();
229+ if (kind .equals (Kind .IDENT ) || kind .equals (Kind .COMPREHENSION )) {
230+ return true ;
231+ }
232+ return isAllowedInConstantExpr (node );
233+ });
234+ }
235+
236+ private boolean isAllowedInConstantExpr (CelNavigableMutableExpr node ) {
237+ Kind kind = node .getKind ();
238+ if (kind .equals (Kind .CONSTANT )
239+ || kind .equals (Kind .LIST )
240+ || kind .equals (Kind .MAP )
241+ || kind .equals (Kind .STRUCT )
242+ || kind .equals (Kind .SELECT )) {
243+ return true ;
244+ }
245+ if (kind .equals (Kind .CALL )) {
246+ CelMutableCall call = node .expr ().call ();
247+ return foldableFunctions .contains (call .function ());
248+ }
249+
250+ return false ;
251+ }
252+
253+ private boolean isAllowedInFoldableExpr (CelNavigableMutableExpr node ) {
254+ Kind kind = node .getKind ();
255+ if (kind .equals (Kind .IDENT ) || kind .equals (Kind .COMPREHENSION )) {
256+ return true ;
257+ }
258+ return isAllowedInConstantExpr (node );
259+ }
260+
210261 private boolean containsFoldableFunctionOnly (CelNavigableMutableExpr navigableExpr ) {
211262 return navigableExpr
212263 .allNodes ()
@@ -248,22 +299,6 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
248299 return true ;
249300 }
250301
251- private static boolean areChildrenArgConstant (CelNavigableMutableExpr expr ) {
252- if (expr .getKind ().equals (Kind .CONSTANT )) {
253- return true ;
254- }
255-
256- if (expr .getKind ().equals (Kind .CALL )
257- || expr .getKind ().equals (Kind .LIST )
258- || expr .getKind ().equals (Kind .MAP )
259- || expr .getKind ().equals (Kind .SELECT )
260- || expr .getKind ().equals (Kind .STRUCT )) {
261- return expr .children ().allMatch (ConstantFoldingOptimizer ::areChildrenArgConstant );
262- }
263-
264- return false ;
265- }
266-
267302 private static boolean isNestedComprehension (CelNavigableMutableExpr expr ) {
268303 Optional <CelNavigableMutableExpr > maybeParent = expr .parent ();
269304 while (maybeParent .isPresent ()) {
0 commit comments