Skip to content

Commit eb1b5a3

Browse files
Modernize inconsistent equality
1 parent 4c5c4e0 commit eb1b5a3

File tree

6 files changed

+92
-54
lines changed

6 files changed

+92
-54
lines changed

python/ql/lib/semmle/python/Class.qll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ class Class extends Class_, Scope, AstNode {
9191
/** Gets a method defined in this class */
9292
Function getAMethod() { result.getScope() = this }
9393

94+
/** Gets the method defined in this class with the specified name, if any. */
95+
Function getMethod(string name) {
96+
result = this.getAMethod() and
97+
result.getName() = name
98+
}
99+
94100
override Location getLocation() { py_scope_location(result, this) }
95101

96102
/** Gets the scope (module, class or function) in which this class is defined */
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
/** Helper definitions for reasoning about comparison methods. */
2+
3+
import python
4+
import semmle.python.ApiGraphs
5+
6+
/** Holds if `cls` has the `functools.total_ordering` decorator. */
7+
predicate totalOrdering(Class cls) {
8+
cls.getADecorator() =
9+
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
10+
}

python/ql/src/Classes/Comparisons/EqualsOrNotEquals.ql

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
* @name Inconsistent equality and inequality
33
* @description Defining only an equality method or an inequality method for a class violates the object model.
44
* @kind problem
5-
* @tags reliability
5+
* @tags quality
6+
* reliability
67
* correctness
78
* @problem.severity warning
89
* @sub-severity high
@@ -11,38 +12,29 @@
1112
*/
1213

1314
import python
14-
import Equality
15+
import Comparisons
16+
import semmle.python.dataflow.new.internal.DataFlowDispatch
17+
import Classes.Equality
1518

16-
string equals_or_ne() { result = "__eq__" or result = "__ne__" }
17-
18-
predicate total_ordering(Class cls) {
19-
exists(Attribute a | a = cls.getADecorator() | a.getName() = "total_ordering")
19+
predicate missingEquality(Class cls, Function defined, string missing) {
20+
defined = cls.getMethod("__ne__") and
21+
not exists(cls.getMethod("__eq__")) and
22+
missing = "__eq__"
2023
or
21-
exists(Name n | n = cls.getADecorator() | n.getId() = "total_ordering")
22-
}
23-
24-
CallableValue implemented_method(ClassValue c, string name) {
25-
result = c.declaredAttribute(name) and name = equals_or_ne()
26-
}
27-
28-
string unimplemented_method(ClassValue c) {
29-
not c.declaresAttribute(result) and result = equals_or_ne()
30-
}
31-
32-
predicate violates_equality_contract(
33-
ClassValue c, string present, string missing, CallableValue method
34-
) {
35-
missing = unimplemented_method(c) and
36-
method = implemented_method(c, present) and
37-
not c.failedInference(_) and
38-
not total_ordering(c.getScope()) and
39-
/* Python 3 automatically implements __ne__ if __eq__ is defined, but not vice-versa */
40-
not (major_version() = 3 and present = "__eq__" and missing = "__ne__") and
41-
not method.getScope() instanceof DelegatingEqualityMethod and
42-
not c.lookup(missing).(CallableValue).getScope() instanceof DelegatingEqualityMethod
24+
// In python 3, __ne__ automatically delegates to __eq__ if its not defined in the hierarchy
25+
// However if it is defined in a superclass (and isn't a delegation method) then it will use the superclass method (which may be incorrect)
26+
defined = cls.getMethod("__eq__") and
27+
not exists(cls.getMethod("__ne__")) and
28+
exists(Function neMeth |
29+
neMeth = getADirectSuperclass+(cls).getMethod("__ne__") and
30+
not neMeth instanceof DelegatingEqualityMethod
31+
) and
32+
missing = "__ne__"
4333
}
4434

45-
from ClassValue c, string present, string missing, CallableValue method
46-
where violates_equality_contract(c, present, missing, method)
47-
select method, "Class $@ implements " + present + " but does not implement " + missing + ".", c,
48-
c.getName()
35+
from Class cls, Function defined, string missing
36+
where
37+
not totalOrdering(cls) and
38+
missingEquality(cls, defined, missing)
39+
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
40+
defined.getName()

python/ql/src/Classes/Comparisons/IncompleteOrdering.ql

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,20 @@
1414
import python
1515
import semmle.python.dataflow.new.internal.DataFlowDispatch
1616
import semmle.python.ApiGraphs
17-
18-
predicate totalOrdering(Class cls) {
19-
cls.getADecorator() =
20-
API::moduleImport("functools").getMember("total_ordering").asSource().asExpr()
21-
}
22-
23-
Function getMethod(Class cls, string name) {
24-
result = cls.getAMethod() and
25-
result.getName() = name
26-
}
17+
import Comparisons
2718

2819
predicate definesStrictOrdering(Class cls, Function meth) {
29-
meth = getMethod(cls, "__lt__")
20+
meth = cls.getMethod("__lt__")
3021
or
31-
not exists(getMethod(cls, "__lt__")) and
32-
meth = getMethod(cls, "__gt__")
22+
not exists(cls.getMethod("__lt__")) and
23+
meth = cls.getMethod("__gt__")
3324
}
3425

3526
predicate definesNonStrictOrdering(Class cls, Function meth) {
36-
meth = getMethod(cls, "__le__")
27+
meth = cls.getMethod("__le__")
3728
or
38-
not exists(getMethod(cls, "__le__")) and
39-
meth = getMethod(cls, "__ge__")
29+
not exists(cls.getMethod("__le__")) and
30+
meth = cls.getMethod("__ge__")
4031
}
4132

4233
predicate missingComparison(Class cls, Function defined, string missing) {
@@ -53,5 +44,5 @@ from Class cls, Function defined, string missing
5344
where
5445
not totalOrdering(cls) and
5546
missingComparison(cls, defined, missing)
56-
select cls, "This class implements $@, but does not implement an " + missing + " method.", defined,
47+
select cls, "This class implements $@, but does not implement " + missing + ".", defined,
5748
defined.getName()

python/ql/src/Classes/Comparisons/examples/EqualsOrNotEquals.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,27 @@ def __eq__(self, other):
3030
def __ne__(self, other): # Improved: equality and inequality method defined (hash method still missing)
3131
return not self == other
3232

33+
34+
35+
class A:
36+
def __init__(self, a):
37+
self.a = a
38+
39+
def __eq__(self, other):
40+
print("A eq")
41+
return self.a == other.a
42+
43+
def __ne__(self, other):
44+
print("A ne")
45+
return self.a != other.a
46+
47+
class B(A):
48+
def __init__(self, a, b):
49+
self.a = a
50+
self.b = b
51+
52+
def __eq__(self, other):
53+
print("B eq")
54+
return self.a == other.a and self.b == other.b
55+
56+
print(B(1,2) != B(1,3))

python/ql/src/Classes/Equality.qll

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
/** Utility definitions for reasoning about equality methods. */
2+
13
import python
4+
import semmle.python.dataflow.new.DataFlow
25

36
private Attribute dictAccess(LocalVariable var) {
47
result.getName() = "__dict__" and
@@ -59,16 +62,28 @@ class IdentityEqMethod extends Function {
5962
/** An (in)equality method that delegates to its complement */
6063
class DelegatingEqualityMethod extends Function {
6164
DelegatingEqualityMethod() {
62-
exists(Return ret, UnaryExpr not_, Compare comp, Cmpop op, Parameter p0, Parameter p1 |
65+
exists(Return ret, UnaryExpr not_, Expr comp, Parameter p0, Parameter p1 |
6366
ret.getScope() = this and
6467
ret.getValue() = not_ and
6568
not_.getOp() instanceof Not and
66-
not_.getOperand() = comp and
67-
comp.compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
69+
not_.getOperand() = comp
6870
|
69-
this.getName() = "__eq__" and op instanceof NotEq
71+
exists(Cmpop op |
72+
comp.(Compare).compares(p0.getVariable().getAnAccess(), op, p1.getVariable().getAnAccess())
73+
|
74+
this.getName() = "__eq__" and op instanceof NotEq
75+
or
76+
this.getName() = "__ne__" and op instanceof Eq
77+
)
7078
or
71-
this.getName() = "__ne__" and op instanceof Eq
79+
exists(DataFlow::MethodCallNode call, string name |
80+
call.calls(DataFlow::exprNode(p0.getVariable().getAnAccess()), name) and
81+
call.getArg(0).asExpr() = p1.getVariable().getAnAccess()
82+
|
83+
this.getName() = "__eq__" and name = "__ne__"
84+
or
85+
this.getName() = "__ne__" and name = "__eq__"
86+
)
7287
)
7388
}
7489
}

0 commit comments

Comments
 (0)