Skip to content

Commit 46364c7

Browse files
committed
Rust: Rework call resolution and type inference for calls
1 parent b59b8fa commit 46364c7

File tree

7 files changed

+2468
-1519
lines changed

7 files changed

+2468
-1519
lines changed

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
1212
* the canonical path `path` and the method name `method`, and if it borrows its
1313
* first `borrows` arguments.
1414
*/
15-
private predicate isOverloaded(string op, int arity, string path, string method, int borrows) {
15+
predicate isOverloaded(string op, int arity, string path, string method, int borrows) {
1616
arity = 1 and
1717
(
1818
// Negation

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,83 @@ class StringStruct extends Struct {
213213
pragma[nomagic]
214214
StringStruct() { this.getCanonicalPath() = "alloc::string::String" }
215215
}
216+
217+
/**
218+
* The [`Deref` trait][1].
219+
*
220+
* [1]: https://doc.rust-lang.org/core/ops/trait.Deref.html
221+
*/
222+
class DerefTrait extends Trait {
223+
pragma[nomagic]
224+
DerefTrait() { this.getCanonicalPath() = "core::ops::deref::Deref" }
225+
226+
/** Gets the `deref` function. */
227+
Function getDerefFunction() { result = this.(TraitItemNode).getAssocItem("deref") }
228+
229+
/** Gets the `Target` associated type. */
230+
pragma[nomagic]
231+
TypeAlias getTargetType() {
232+
result = this.getAssocItemList().getAnAssocItem() and
233+
result.getName().getText() = "Target"
234+
}
235+
}
236+
237+
/**
238+
* The [`Index` trait][1].
239+
*
240+
* [1]: https://doc.rust-lang.org/std/ops/trait.Index.html
241+
*/
242+
class IndexTrait extends Trait {
243+
pragma[nomagic]
244+
IndexTrait() { this.getCanonicalPath() = "core::ops::index::Index" }
245+
246+
/** Gets the `index` function. */
247+
Function getIndexFunction() { result = this.(TraitItemNode).getAssocItem("index") }
248+
249+
/** Gets the `Output` associated type. */
250+
pragma[nomagic]
251+
TypeAlias getOutputType() {
252+
result = this.getAssocItemList().getAnAssocItem() and
253+
result.getName().getText() = "Output"
254+
}
255+
}
256+
257+
/**
258+
* The [`Box` struct][1].
259+
*
260+
* [1]: https://doc.rust-lang.org/std/boxed/struct.Box.html
261+
*/
262+
class BoxStruct extends Struct {
263+
pragma[nomagic]
264+
BoxStruct() { this.getCanonicalPath() = "alloc::boxed::Box" }
265+
}
266+
267+
/**
268+
* The [`Rc` struct][1].
269+
*
270+
* [1]: https://doc.rust-lang.org/std/rc/struct.Rc.html
271+
*/
272+
class RcStruct extends Struct {
273+
pragma[nomagic]
274+
RcStruct() { this.getCanonicalPath() = "alloc::rc::Rc" }
275+
}
276+
277+
/**
278+
* The [`Arc` struct][1].
279+
*
280+
* [1]: https://doc.rust-lang.org/std/sync/struct.Arc.html
281+
*/
282+
class ArcStruct extends Struct {
283+
pragma[nomagic]
284+
ArcStruct() { this.getCanonicalPath() = "alloc::sync::Arc" }
285+
}
286+
287+
/**
288+
* The [`Pin` struct][1].
289+
*
290+
* [1]: https://doc.rust-lang.org/std/pin/struct.Pin.html
291+
*/
292+
class PinStruct extends Struct {
293+
pragma[nomagic]
294+
PinStruct() { this.getCanonicalPath() = "core::pin::Pin" }
295+
}
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
private import rust
2+
private import TypeInference
3+
private import PathResolution
4+
private import Type
5+
private import TypeMention
6+
private import codeql.rust.elements.Call
7+
8+
private newtype TFunctionTypePosition =
9+
TArgumentFunctionTypePosition(ArgumentPosition pos) or
10+
TReturnFunctionTypePosition()
11+
12+
/**
13+
* A position of a type related to a function.
14+
*
15+
* Either `self`, `return`, or a positional parameter index.
16+
*/
17+
class FunctionTypePosition extends TFunctionTypePosition {
18+
predicate isSelf() { this.asArgumentPosition().isSelf() }
19+
20+
int asPositional() { result = this.asArgumentPosition().asPosition() }
21+
22+
predicate isPositional() { exists(this.asPositional()) }
23+
24+
ArgumentPosition asArgumentPosition() { this = TArgumentFunctionTypePosition(result) }
25+
26+
predicate isReturn() { this = TReturnFunctionTypePosition() }
27+
28+
/** Gets the corresponding position when `f` is invoked via a function call. */
29+
bindingset[f]
30+
FunctionTypePosition getFunctionCallAdjusted(Function f) {
31+
this.isReturn() and
32+
result = this
33+
or
34+
if f.hasSelfParam()
35+
then
36+
this.isSelf() and result.asPositional() = 0
37+
or
38+
result.asPositional() = this.asPositional() + 1
39+
else result = this
40+
}
41+
42+
TypeMention getTypeMention(Function f) {
43+
this.isSelf() and
44+
result = getSelfParamTypeMention(f.getSelfParam())
45+
or
46+
result = f.getParam(this.asPositional()).getTypeRepr()
47+
or
48+
this.isReturn() and
49+
result = f.getRetType().getTypeRepr()
50+
}
51+
52+
string toString() {
53+
result = this.asArgumentPosition().toString()
54+
or
55+
this.isReturn() and
56+
result = "(return)"
57+
}
58+
}
59+
60+
pragma[nomagic]
61+
predicate functionTypeAtPath(Function f, FunctionTypePosition pos, TypePath path, Type type) {
62+
type = pos.getTypeMention(f).resolveTypeAt(path)
63+
}
64+
65+
/**
66+
* A helper module for implementing `Matching(WithEnvironment)InputSig` with
67+
* `DeclarationPosition = AccessPosition = FunctionTypePosition`.
68+
*/
69+
module FunctionTypePositionMatchingInput {
70+
class DeclarationPosition = FunctionTypePosition;
71+
72+
class AccessPosition = DeclarationPosition;
73+
74+
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
75+
apos = dpos
76+
}
77+
}
78+
79+
private newtype TFunctionType =
80+
MkFunctionType(Function f, FunctionTypePosition pos, ImplOrTraitItemNode i) {
81+
f = i.getAnAssocItem() and
82+
exists(pos.getTypeMention(f))
83+
} or
84+
MkInheritedFunctionType(
85+
Function f, FunctionTypePosition pos, ImplOrTraitItemNode parent, ImplOrTraitItemNode i
86+
) {
87+
exists(FunctionType inherited |
88+
inherited.appliesTo(f, pos, parent) and
89+
f = i.getASuccessor(_)
90+
|
91+
parent = i.(ImplItemNode).resolveTraitTy()
92+
or
93+
parent = i.(TraitItemNode).resolveABound()
94+
)
95+
}
96+
97+
/**
98+
* The type of a function at a given position, when viewed as a member of a given
99+
* trait or `impl` block.
100+
*
101+
* Example:
102+
*
103+
* ```rust
104+
* trait T1 {
105+
* fn m1(self); // self1
106+
*
107+
* fn m2(self) { ... } // self2
108+
* }
109+
*
110+
* trait T2 : T1 {
111+
* fn m3(self); // self3
112+
* }
113+
*
114+
* impl T2 for X {
115+
* fn m1(self) { ... } // self4
116+
*
117+
* fn m3(self) { ... } // self5
118+
* }
119+
* ```
120+
*
121+
* param | `impl` or trait | type
122+
* ------- | --------------- | ----
123+
* `self1` | `trait T1` | `T1`
124+
* `self1` | `trait T2` | `T2`
125+
* `self2` | `trait T1` | `T1`
126+
* `self2` | `trait T2` | `T2`
127+
* `self2` | `impl T2 for X` | `X`
128+
* `self3` | `trait T2` | `T2`
129+
* `self4` | `impl T2 for X` | `X`
130+
* `self5` | `impl T2 for X` | `X`
131+
*/
132+
class FunctionType extends TFunctionType {
133+
private predicate isFunctionType(Function f, FunctionTypePosition pos, ImplOrTraitItemNode i) {
134+
this = MkFunctionType(f, pos, i)
135+
}
136+
137+
private predicate isInheritedFunctionType(
138+
Function f, FunctionTypePosition pos, ImplOrTraitItemNode parent, ImplOrTraitItemNode i
139+
) {
140+
this = MkInheritedFunctionType(f, pos, parent, i)
141+
}
142+
143+
/**
144+
* Holds if this function type applies to the function `f` at position `pos`,
145+
* when viewed as a member of the `impl` or trait item `i`.
146+
*/
147+
predicate appliesTo(Function f, FunctionTypePosition pos, ImplOrTraitItemNode i) {
148+
this.isFunctionType(f, pos, i)
149+
or
150+
this.isInheritedFunctionType(f, pos, _, i)
151+
}
152+
153+
pragma[nomagic]
154+
private Type getTypeAt0(TypePath path) {
155+
exists(Function f, FunctionTypePosition pos |
156+
this.isFunctionType(f, pos, _) and
157+
functionTypeAtPath(f, pos, path, result)
158+
)
159+
or
160+
exists(
161+
Function f, FunctionTypePosition pos, FunctionType parentType, ImplOrTraitItemNode parent,
162+
ImplOrTraitItemNode i
163+
|
164+
this.isInheritedFunctionType(f, pos, parent, i) and
165+
parentType.appliesTo(f, pos, parent)
166+
|
167+
result = parentType.getTypeAt0(path) and
168+
not result instanceof TSelfTypeParameter
169+
or
170+
exists(TypePath prefix, TypePath suffix |
171+
parentType.hasSelfTypeParameterAt(prefix) and
172+
result = resolveImplOrTraitType(i, suffix) and
173+
path = prefix.append(suffix)
174+
)
175+
)
176+
}
177+
178+
pragma[nomagic]
179+
private predicate hasSelfTypeParameterAt(TypePath path) {
180+
this.getTypeAt0(path) = TSelfTypeParameter(_)
181+
}
182+
183+
/**
184+
* Gets the type of this function at the given position and path.
185+
*
186+
* For functions belonging to a `trait`, we use the type of the trait itself instead
187+
* of the implicit `Self` type parameter, as otherwise any type will match.
188+
*
189+
* Calls should use `substituteLookupTraits` to map receiver types to the relevant
190+
* traits when matching.
191+
*/
192+
Type getTypeAt(TypePath path) {
193+
exists(Type t | t = this.getTypeAt0(path) |
194+
not t instanceof SelfTypeParameter and
195+
result = t
196+
or
197+
result = TTrait(t.(SelfTypeParameter).getTrait())
198+
)
199+
}
200+
201+
private AstNode getReportingNode() {
202+
exists(Function f, FunctionTypePosition pos | this.appliesTo(f, pos, _) |
203+
pos.isSelf() and
204+
exists(SelfParam self | self = f.getSelfParam() |
205+
result = self.getTypeRepr()
206+
or
207+
not self.hasTypeRepr() and
208+
result = self
209+
)
210+
or
211+
result = f.getParam(pos.asPositional()).getTypeRepr()
212+
or
213+
pos.isReturn() and
214+
result = f.getRetType().getTypeRepr()
215+
)
216+
}
217+
218+
string toString() { result = this.getReportingNode().toString() }
219+
220+
Location getLocation() { result = this.getReportingNode().getLocation() }
221+
}
222+
223+
private Trait getALookupTrait(Type t) {
224+
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveABound()
225+
or
226+
result = t.(SelfTypeParameter).getTrait()
227+
or
228+
result = t.(ImplTraitType).getImplTraitTypeRepr().(ImplTraitTypeReprItemNode).resolveABound()
229+
or
230+
result = t.(DynTraitType).getTrait()
231+
}
232+
233+
/**
234+
* Gets the type obtained by substituting in relevant traits in which to do function
235+
* lookup, or `t` itself when no such trait exist.
236+
*/
237+
bindingset[t]
238+
Type substituteLookupTraits(Type t) {
239+
not exists(getALookupTrait(t)) and
240+
result = t
241+
or
242+
result = TTrait(getALookupTrait(t))
243+
}

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ final class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
810810
}
811811
}
812812

813-
final private class ImplTraitTypeReprItemNode extends TypeItemNode instanceof ImplTraitTypeRepr {
813+
final class ImplTraitTypeReprItemNode extends TypeItemNode instanceof ImplTraitTypeRepr {
814814
pragma[nomagic]
815815
Path getABoundPath() {
816816
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
@@ -1914,7 +1914,7 @@ private predicate builtin(string name, ItemNode i) {
19141914

19151915
/** Provides predicates for debugging the path resolution implementation. */
19161916
private module Debug {
1917-
private Locatable getRelevantLocatable() {
1917+
Locatable getRelevantLocatable() {
19181918
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
19191919
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
19201920
filepath.matches("%/main.rs") and

0 commit comments

Comments
 (0)