diff --git a/shared/controlflow/codeql/controlflow/Guards.qll b/shared/controlflow/codeql/controlflow/Guards.qll index 17aee2a7caee..8fd39279ff65 100644 --- a/shared/controlflow/codeql/controlflow/Guards.qll +++ b/shared/controlflow/codeql/controlflow/Guards.qll @@ -1008,6 +1008,8 @@ module Make< * wrappers. */ private module WrapperGuard { + private import codeql.util.DenseRank + final private class FinalExpr = Expr; class ReturnExpr extends FinalExpr { @@ -1019,21 +1021,58 @@ module Make< BasicBlock getBasicBlock() { result = super.getBasicBlock() } } + private module DenseRankInput implements DenseRankInputSig1 { + class C = NonOverridableMethod; + + class Ranked = ReturnExpr; + + int getRank(NonOverridableMethod m, ReturnExpr ret) { + m.getAReturnExpr() = ret and + result = ret.getLocation().getStartLine() + } + } + + private module ReturnExprRank = DenseRank1; + + private predicate rankedReturnExpr = ReturnExprRank::denseRank/2; + + private int maxRank(NonOverridableMethod m) { + result = max(int rnk | exists(rankedReturnExpr(m, rnk))) + } + private predicate relevantCallValue(NonOverridableMethodCall call, GuardValue val) { BranchImplies::guardControls(call, val, _, _) or ReturnImplies::guardControls(call, val, _, _) } - predicate relevantReturnValue(NonOverridableMethod m, GuardValue val) { + /** + * Holds if a call to `m` having a return value of `retval` is reachable + * by a chain of implications. + */ + predicate relevantReturnValue(NonOverridableMethod m, GuardValue retval) { exists(NonOverridableMethodCall call | - relevantCallValue(call, val) and + relevantCallValue(call, retval) and call.getMethod() = m and - not val instanceof TException + not retval instanceof TException + ) + } + + /** + * Holds if a call to `m` having a return value of `retval` is reachable + * by a chain of implications, and `ret` is a return expression in `m` + * that could possibly have the value `retval`. + */ + predicate relevantReturnExprValue(NonOverridableMethod m, ReturnExpr ret, GuardValue retval) { + relevantReturnValue(m, retval) and + ret = m.getAReturnExpr() and + not exists(GuardValue notRetval | + exprHasValue(ret, notRetval) and + disjointValues(notRetval, retval) ) } private predicate returnGuard(Guard guard, GuardValue val) { - relevantReturnValue(guard.(ReturnExpr).getMethod(), val) + relevantReturnExprValue(_, guard, val) } module ReturnImplies = ImpliesTC; @@ -1043,28 +1082,58 @@ module Make< guard.directlyValueControls(ret.getBasicBlock(), val) } + private predicate parameterControlsReturnExpr( + SsaParameterInit param, GuardValue val, ReturnExpr ret + ) { + exists(Guard g0, GuardValue v0 | + directlyControlsReturn(g0, v0, ret) and + BranchImplies::ssaControls(param, val, g0, v0) + ) + } + /** * Holds if `ret` is a return expression in a non-overridable method that * on a return value of `retval` allows the conclusion that the `ppos`th * parameter has the value `val`. */ private predicate validReturnInCustomGuard( - ReturnExpr ret, ParameterPosition ppos, GuardValue retval, GuardValue val + ReturnExpr ret, int rnk, NonOverridableMethod m, ParameterPosition ppos, GuardValue retval, + GuardValue val ) { - exists(NonOverridableMethod m, SsaParameterInit param | - m.getAReturnExpr() = ret and + exists(SsaParameterInit param | + ret = rankedReturnExpr(m, rnk) and param.getParameter() = m.getParameter(ppos) | - exists(Guard g0, GuardValue v0 | - directlyControlsReturn(g0, v0, ret) and - BranchImplies::ssaControls(param, val, g0, v0) and - relevantReturnValue(m, retval) - ) + parameterControlsReturnExpr(param, val, ret) and + relevantReturnExprValue(m, ret, retval) or ReturnImplies::ssaControls(param, val, ret, retval) ) } + private predicate validReturnInCustomGuardToRank( + int rnk, NonOverridableMethod m, ParameterPosition ppos, GuardValue retval, GuardValue val + ) { + // The forall-range has been pushed all the way into + // `relevantReturnExprValue` and `validReturnInCustomGuard`. This means + // that this base case ensures that at least one return expression + // non-vacuously satisfies that it's a valid implication from return + // value to parameter value. + validReturnInCustomGuard(_, _, m, ppos, retval, val) and rnk = 0 + or + validReturnInCustomGuardToRank(rnk - 1, m, ppos, retval, val) and + rnk <= maxRank(m) and + forall(ReturnExpr ret | + ret = rankedReturnExpr(m, rnk) and + not exists(GuardValue notRetval | + exprHasValue(ret, notRetval) and + disjointValues(notRetval, retval) + ) + | + validReturnInCustomGuard(ret, rnk, m, ppos, retval, val) + ) + } + private predicate guardDirectlyControlsExit(Guard guard, GuardValue val) { exists(BasicBlock bb | guard.directlyValueControls(bb, val) and @@ -1080,15 +1149,7 @@ module Make< private NonOverridableMethod wrapperGuard( ParameterPosition ppos, GuardValue retval, GuardValue val ) { - forex(ReturnExpr ret | - result.getAReturnExpr() = ret and - not exists(GuardValue notRetval | - exprHasValue(ret, notRetval) and - disjointValues(notRetval, retval) - ) - | - validReturnInCustomGuard(ret, ppos, retval, val) - ) + validReturnInCustomGuardToRank(maxRank(result), result, ppos, retval, val) or exists(SsaParameterInit param, Guard g0, GuardValue v0 | param.getParameter() = result.getParameter(ppos) and @@ -1166,7 +1227,7 @@ module Make< guardChecksDef(guard, param, val, state) | guard.valueControls(ret.getBasicBlock(), val) and - relevantReturnValue(m, retval) + relevantReturnExprValue(m, ret, retval) or ReturnImplies::guardControls(guard, val, ret, retval) )