diff --git a/java/ql/lib/semmle/code/java/controlflow/Guards.qll b/java/ql/lib/semmle/code/java/controlflow/Guards.qll index be023939b8c0..8aa537cc24d0 100644 --- a/java/ql/lib/semmle/code/java/controlflow/Guards.qll +++ b/java/ql/lib/semmle/code/java/controlflow/Guards.qll @@ -294,27 +294,23 @@ private module GuardsInput implements SharedGuards::InputSig { equals.getNumberOfParameters() = 2 } - class EqualityTest extends Expr { - EqualityTest() { - this instanceof J::EqualityTest or - this.(MethodCall).getMethod() instanceof EqualsMethod or - objectsEquals(this.(MethodCall).getMethod()) - } - - Expr getAnOperand() { - result = this.(J::EqualityTest).getAnOperand() - or - result = this.(MethodCall).getAnArgument() - or - this.(MethodCall).getMethod() instanceof EqualsMethod and - result = this.(MethodCall).getQualifier() - } - - boolean polarity() { - result = this.(J::EqualityTest).polarity() + pragma[nomagic] + predicate equalityTest(Expr eqtest, Expr left, Expr right, boolean polarity) { + exists(EqualityTest eq | eq = eqtest | + eq.getLeftOperand() = left and + eq.getRightOperand() = right and + eq.polarity() = polarity + ) + or + exists(MethodCall call | call = eqtest and polarity = true | + call.getMethod() instanceof EqualsMethod and + call.getQualifier() = left and + call.getAnArgument() = right or - result = true and not this instanceof J::EqualityTest - } + objectsEquals(call.getMethod()) and + call.getArgument(0) = left and + call.getArgument(1) = right + ) } class ConditionalExpr extends Expr instanceof J::ConditionalExpr { diff --git a/shared/controlflow/codeql/controlflow/Guards.qll b/shared/controlflow/codeql/controlflow/Guards.qll index 627e0e1694fe..887eef9021a0 100644 --- a/shared/controlflow/codeql/controlflow/Guards.qll +++ b/shared/controlflow/codeql/controlflow/Guards.qll @@ -188,13 +188,12 @@ signature module InputSig { Expr getEqualChildExpr(); } - class EqualityTest extends Expr { - /** Gets an operand of this expression. */ - Expr getAnOperand(); - - /** Gets a boolean indicating whether this test is equality (true) or inequality (false). */ - boolean polarity(); - } + /** + * Holds if `eqtest` is an equality or inequality test between `left` and + * `right`. The `polarity` indicates whether this is an equality test (true) + * or inequality test (false). + */ + predicate equalityTest(Expr eqtest, Expr left, Expr right, boolean polarity); class ConditionalExpr extends Expr { /** Gets the condition of this expression. */ @@ -351,12 +350,10 @@ module Make Input> { c.nonMatchEdge(bb1, bb2) } - pragma[nomagic] - private predicate eqtestHasOperands(EqualityTest eqtest, Expr e1, Expr e2, boolean polarity) { - eqtest.getAnOperand() = e1 and - eqtest.getAnOperand() = e2 and - e1 != e2 and - eqtest.polarity() = polarity + private predicate equalityTestSymmetric(Expr eqtest, Expr e1, Expr e2, boolean eqval) { + equalityTest(eqtest, e1, e2, eqval) + or + equalityTest(eqtest, e2, e1, eqval) } private predicate constcaseEquality(PreGuard g, Expr e1, ConstantExpr e2) { @@ -424,7 +421,7 @@ module Make Input> { * to `eqval`. */ predicate isEquality(Expr e1, Expr e2, boolean eqval) { - eqtestHasOperands(this, e1, e2, eqval) + equalityTestSymmetric(this, e1, e2, eqval) or constcaseEquality(this, e1, e2) and eqval = true or @@ -466,7 +463,7 @@ module Make Input> { ) or exists(NonNullExpr nonnull | - eqtestHasOperands(g1, g2, nonnull, v1.asBooleanValue()) and + equalityTestSymmetric(g1, g2, nonnull, v1.asBooleanValue()) and v2.isNonNullValue() ) or @@ -589,7 +586,7 @@ module Make Input> { private predicate guardChecksEqualVars( Guard guard, SsaDefinition v1, SsaDefinition v2, boolean branch ) { - eqtestHasOperands(guard, v1.getARead(), v2.getARead(), branch) + equalityTestSymmetric(guard, v1.getARead(), v2.getARead(), branch) } private predicate guardReadsSsaVar(Guard guard, SsaDefinition def) { @@ -773,7 +770,7 @@ module Make Input> { or exists(Expr nonnull | exprHasValue(nonnull, v2) and - eqtestHasOperands(g1, g2, nonnull, v1.asBooleanValue()) and + equalityTestSymmetric(g1, g2, nonnull, v1.asBooleanValue()) and v2.isNonNullValue() ) }