Skip to content

Commit d4435d8

Browse files
committed
Rust: Disambiguate calls to associated functions
1 parent ae8af42 commit d4435d8

File tree

5 files changed

+79
-23
lines changed

5 files changed

+79
-23
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallExprBaseImpl.qll

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ private import codeql.rust.elements.Resolvable
1313
*/
1414
module Impl {
1515
private import rust
16+
private import codeql.rust.internal.TypeInference as TypeInference
1617

1718
pragma[nomagic]
1819
Resolvable getCallResolvable(CallExprBase call) {
@@ -27,7 +28,11 @@ module Impl {
2728
*/
2829
class CallExprBase extends Generated::CallExprBase {
2930
/** Gets the static target of this call, if any. */
30-
Callable getStaticTarget() { none() } // overridden by subclasses, but cannot be made abstract
31+
final Callable getStaticTarget() {
32+
result = TypeInference::resolveMethodCallTarget(this)
33+
or
34+
result = TypeInference::resolveCallTarget(this)
35+
}
3136

3237
override Expr getArg(int index) { result = this.getArgList().getArg(index) }
3338
}

rust/ql/lib/codeql/rust/elements/internal/CallExprImpl.qll

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ private import codeql.rust.elements.PathExpr
1414
module Impl {
1515
private import rust
1616
private import codeql.rust.internal.PathResolution as PathResolution
17-
private import codeql.rust.internal.TypeInference as TypeInference
1817

1918
pragma[nomagic]
2019
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
@@ -37,15 +36,6 @@ module Impl {
3736
class CallExpr extends Generated::CallExpr {
3837
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }
3938

40-
override Callable getStaticTarget() {
41-
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
42-
// if type inference can resolve it to the correct trait implementation.
43-
result = TypeInference::resolveMethodCallTarget(this)
44-
or
45-
not exists(TypeInference::resolveMethodCallTarget(this)) and
46-
result = getResolvedFunction(this)
47-
}
48-
4939
/** Gets the struct that this call resolves to, if any. */
5040
Struct getStruct() { result = getResolvedFunction(this) }
5141

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ module Impl {
6262
Function getStaticTarget() {
6363
result = TypeInference::resolveMethodCallTarget(this)
6464
or
65-
not exists(TypeInference::resolveMethodCallTarget(this)) and
6665
result = this.(CallExpr).getStaticTarget()
6766
}
6867

rust/ql/lib/codeql/rust/elements/internal/MethodCallExprImpl.qll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
private import rust
88
private import codeql.rust.elements.internal.generated.MethodCallExpr
9-
private import codeql.rust.internal.PathResolution
10-
private import codeql.rust.internal.TypeInference
119

1210
/**
1311
* INTERNAL: This module contains the customizable definition of `MethodCallExpr` and should not
@@ -23,8 +21,6 @@ module Impl {
2321
* ```
2422
*/
2523
class MethodCallExpr extends Generated::MethodCallExpr {
26-
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }
27-
2824
private string toStringPart(int index) {
2925
index = 0 and
3026
result = this.getReceiver().toAbbreviatedString()

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
1111
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1212
private import codeql.rust.elements.Call
1313
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
14+
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
1415

1516
class Type = T::Type;
1617

@@ -724,8 +725,6 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
724725
}
725726
}
726727

727-
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
728-
729728
final class Access extends Call {
730729
pragma[nomagic]
731730
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
@@ -771,7 +770,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
771770
Declaration getTarget() {
772771
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
773772
or
774-
result = CallExprImpl::getResolvedFunction(this)
773+
result = resolveCallTargetSimple(this)
774+
or
775+
result = resolveCallTargetComplex(this) // mutual recursion
775776
}
776777
}
777778

@@ -1350,7 +1351,7 @@ private predicate implSiblingCandidate(
13501351
// siblings).
13511352
not exists(impl.getAttributeMacroExpansion()) and
13521353
// We use this for resolving methods, so exclude traits that do not have methods.
1353-
exists(Function f | f = trait.getASuccessor(_) and f.getParamList().hasSelfParam()) and
1354+
exists(Function f | f = trait.getASuccessor(_)) and
13541355
selfTy = impl.getSelfTy() and
13551356
rootType = selfTy.resolveType()
13561357
}
@@ -1499,6 +1500,58 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
14991500
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
15001501
}
15011502

1503+
pragma[nomagic]
1504+
private predicate assocFuncResolutionDependsOnArgument(Function f, Impl impl, int pos) {
1505+
methodResolutionDependsOnArgument(impl, _, f, pos, _, _)
1506+
}
1507+
1508+
private class AssocFuncCallExpr extends CallExpr {
1509+
private int pos;
1510+
1511+
AssocFuncCallExpr() {
1512+
assocFuncResolutionDependsOnArgument(CallExprImpl::getResolvedFunction(this), _, pos)
1513+
}
1514+
1515+
Function getACandidate(Impl impl) {
1516+
result = CallExprImpl::getResolvedFunction(this) and
1517+
assocFuncResolutionDependsOnArgument(result, impl, pos)
1518+
}
1519+
1520+
int getPosition() { result = pos }
1521+
1522+
/** Gets the type of the receiver of the associated function call at `path`. */
1523+
Type getTypeAt(TypePath path) { result = inferType(this.getArg(pos), path) }
1524+
}
1525+
1526+
private module AssocFuncIsInstantiationOfInput implements
1527+
IsInstantiationOfInputSig<AssocFuncCallExpr>
1528+
{
1529+
pragma[nomagic]
1530+
predicate potentialInstantiationOf(
1531+
AssocFuncCallExpr ce, TypeAbstraction impl, TypeMention constraint
1532+
) {
1533+
exists(Function cand |
1534+
cand = ce.getACandidate(impl) and
1535+
constraint = cand.getParam(ce.getPosition()).getTypeRepr()
1536+
)
1537+
}
1538+
}
1539+
1540+
pragma[nomagic]
1541+
ItemNode resolveCallTargetSimple(CallExpr ce) {
1542+
result = CallExprImpl::getResolvedFunction(ce) and
1543+
not assocFuncResolutionDependsOnArgument(result, _, _)
1544+
}
1545+
1546+
pragma[nomagic]
1547+
Function resolveCallTargetComplex(AssocFuncCallExpr ce) {
1548+
exists(Impl impl |
1549+
IsInstantiationOf<AssocFuncCallExpr, AssocFuncIsInstantiationOfInput>::isInstantiationOf(ce,
1550+
impl, _) and
1551+
result = getMethodSuccessor(impl, ce.getACandidate(_).getName().getText())
1552+
)
1553+
}
1554+
15021555
cached
15031556
private module Cached {
15041557
private import codeql.rust.internal.CachedStages
@@ -1541,6 +1594,14 @@ private module Cached {
15411594
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
15421595
}
15431596

1597+
/** Gets a method that the method call `mc` resolves to, if any. */
1598+
cached
1599+
Function resolveCallTarget(CallExpr ce) {
1600+
result = resolveCallTargetSimple(ce)
1601+
or
1602+
result = resolveCallTargetComplex(ce)
1603+
}
1604+
15441605
pragma[inline]
15451606
private Type inferRootTypeDeref(AstNode n) {
15461607
result = inferType(n) and
@@ -1685,9 +1746,9 @@ private module Debug {
16851746
result = inferType(n, path)
16861747
}
16871748

1688-
Function debugResolveMethodCallTarget(Call mce) {
1689-
mce = getRelevantLocatable() and
1690-
result = resolveMethodCallTarget(mce)
1749+
Function debugResolveCallTarget(Call c) {
1750+
c = getRelevantLocatable() and
1751+
result = [resolveMethodCallTarget(c), resolveCallTarget(c)]
16911752
}
16921753

16931754
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
@@ -1705,6 +1766,11 @@ private module Debug {
17051766
tm.resolveTypeAt(path) = type
17061767
}
17071768

1769+
Type debugInferAnnotatedType(AstNode n, TypePath path) {
1770+
n = getRelevantLocatable() and
1771+
result = inferAnnotatedType(n, path)
1772+
}
1773+
17081774
pragma[nomagic]
17091775
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
17101776
t = inferType(n, path) and

0 commit comments

Comments
 (0)