Skip to content

Commit 32e3727

Browse files
committed
wip
1 parent 44fdb66 commit 32e3727

File tree

8 files changed

+371
-195
lines changed

8 files changed

+371
-195
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,5 +303,6 @@ class Vec extends Struct {
303303
pragma[nomagic]
304304
Vec() { this.getCanonicalPath() = "alloc::vec::Vec" }
305305

306+
/** Gets the type parameter representing the element type. */
306307
TypeParam getElementTypeParam() { result = this.getGenericParamList().getTypeParam(0) }
307308
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 188 additions & 159 deletions
Large diffs are not rendered by default.

rust/ql/lib/codeql/rust/internal/typeinference/BlanketImplementation.qll

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ module SatisfiesBlanketConstraint<
118118
predicate useUniversalConditions() { none() }
119119
}
120120

121+
private module SatisfiesBlanketConstraint =
122+
SatisfiesConstraint<ArgumentTypeAndBlanketOffset, SatisfiesBlanketConstraintInput>;
123+
121124
/**
122125
* Holds if the argument type `at` satisfies the first non-trivial blanket
123126
* constraint of `impl`.
@@ -127,8 +130,7 @@ module SatisfiesBlanketConstraint<
127130
exists(ArgumentTypeAndBlanketOffset ato, Trait traitBound |
128131
ato = MkArgumentTypeAndBlanketOffset(at, _) and
129132
SatisfiesBlanketConstraintInput::relevantConstraint(ato, impl, traitBound) and
130-
SatisfiesConstraint<ArgumentTypeAndBlanketOffset, SatisfiesBlanketConstraintInput>::satisfiesConstraintType(ato,
131-
TTrait(traitBound), _, _)
133+
SatisfiesBlanketConstraint::satisfiesConstraintType(ato, TTrait(traitBound), _, _)
132134
)
133135
}
134136
}

rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,11 @@ predicate functionResolutionDependsOnArgument(
117117
* method. In that case we will still resolve several methods.
118118
*/
119119

120-
exists(TraitItemNode trait, string functionName, TypePath path0, Type type0 |
120+
exists(TraitItemNode trait, string functionName |
121121
implHasSibling(impl, trait) and
122-
traitTypeParameterOccurrence(trait, _, functionName, pos, path0, _) and
123-
functionTypeAtPath(f, pos, path0, type0) and
122+
traitTypeParameterOccurrence(trait, _, functionName, pos, path, _) and
123+
functionTypeAtPath(f, pos, path, type) and
124124
f = impl.getASuccessor(functionName) and
125125
not pos.isReturn()
126-
|
127-
exists(TypeParameter tp0, TypePath path1, Type type1 |
128-
complexSelfRoot(type0, tp0) and
129-
path1 = path0.append(TypePath::singleton(tp0)) and
130-
functionTypeAtPath(f, pos, path1, type1)
131-
|
132-
if type1 instanceof TypeParameter
133-
then path = path0 and type = type0
134-
else (
135-
path = path1 and type = type1
136-
)
137-
)
138-
or
139-
not complexSelfRoot(type0, _) and
140-
path = path0 and
141-
type = type0
142126
)
143127
}

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,163 @@ Type substituteLookupTraits(Type t) {
241241
or
242242
result = TTrait(getALookupTrait(t))
243243
}
244+
245+
/**
246+
* A wrapper around `IsInstantiationOf` which ensures to substitute in lookup
247+
* traits when checking whether argument types are instantiations of function
248+
* types.
249+
*/
250+
module ArgIsInstantiationOf<HasTypeTreeSig Arg, IsInstantiationOfInputSig<Arg, FunctionType> Input> {
251+
final private class ArgFinal = Arg;
252+
253+
private class ArgSubst extends ArgFinal {
254+
Type getTypeAt(TypePath path) { result = substituteLookupTraits(super.getTypeAt(path)) }
255+
}
256+
257+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<ArgSubst, FunctionType>
258+
{
259+
pragma[nomagic]
260+
predicate potentialInstantiationOf(ArgSubst arg, TypeAbstraction abs, FunctionType constraint) {
261+
Input::potentialInstantiationOf(arg, abs, constraint)
262+
}
263+
264+
predicate relevantTypeMention(FunctionType constraint) {
265+
Input::relevantTypeMention(constraint)
266+
}
267+
}
268+
269+
private module ArgSubstIsInstantiationOf =
270+
IsInstantiationOf<ArgSubst, FunctionType, IsInstantiationOfInput>;
271+
272+
predicate argIsInstantiationOf(Arg arg, ImplOrTraitItemNode i, FunctionType constraint) {
273+
ArgSubstIsInstantiationOf::isInstantiationOf(arg, i, constraint)
274+
}
275+
276+
predicate argIsNotInstantiationOf(Arg arg, ImplOrTraitItemNode i, FunctionType constraint) {
277+
ArgSubstIsInstantiationOf::isNotInstantiationOf(arg, i, constraint)
278+
}
279+
}
280+
281+
/**
282+
* Provides the input for `ArgsAreInstantiationsOf`.
283+
*/
284+
signature module ArgsAreInstantiationsOfInputSig {
285+
/**
286+
* Holds if types need to matched against the type `t` at position `pos` of
287+
* `f` inside `i`.
288+
*/
289+
predicate toCheck(ImplOrTraitItemNode i, Function f, FunctionTypePosition pos, FunctionType t);
290+
291+
/** A call whose argument types are to be checked. */
292+
class Call {
293+
string toString();
294+
295+
Location getLocation();
296+
297+
Type getArgType(FunctionTypePosition pos, TypePath path);
298+
299+
predicate hasTargetCand(ImplOrTraitItemNode i, Function f);
300+
}
301+
}
302+
303+
/**
304+
* Provides logic for checking that a set of arguments have types that are
305+
* instantiations of the types at the corresponding positions in a function
306+
* type.
307+
*/
308+
module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
309+
pragma[nomagic]
310+
private predicate toCheckRanked(
311+
ImplOrTraitItemNode i, Function f, FunctionTypePosition pos, int rnk
312+
) {
313+
Input::toCheck(i, f, pos, _) and
314+
pos =
315+
rank[rnk + 1](FunctionTypePosition pos0, int j |
316+
Input::toCheck(i, f, pos0, _) and
317+
(
318+
j = pos0.asPositional()
319+
or
320+
pos0.isSelf() and j = -1
321+
or
322+
pos0.isReturn() and j = -2
323+
)
324+
|
325+
pos0 order by j
326+
)
327+
}
328+
329+
private newtype TCallAndPos =
330+
MkCallAndPos(Input::Call call, FunctionTypePosition pos) { exists(call.getArgType(pos, _)) }
331+
332+
/** A call tagged with a position. */
333+
private class CallAndPos extends MkCallAndPos {
334+
Input::Call call;
335+
FunctionTypePosition pos;
336+
337+
CallAndPos() { this = MkCallAndPos(call, pos) }
338+
339+
Input::Call getCall() { result = call }
340+
341+
FunctionTypePosition getPos() { result = pos }
342+
343+
Location getLocation() { result = call.getLocation() }
344+
345+
Type getTypeAt(TypePath path) { result = call.getArgType(pos, path) }
346+
347+
string toString() { result = call.toString() + " [arg " + pos + "]" }
348+
}
349+
350+
private module ArgIsInstantiationOfInput implements
351+
IsInstantiationOfInputSig<CallAndPos, FunctionType>
352+
{
353+
pragma[nomagic]
354+
private predicate potentialInstantiationOf0(
355+
CallAndPos cp, Input::Call call, FunctionTypePosition pos, int rnk, Function f,
356+
TypeAbstraction abs, FunctionType constraint
357+
) {
358+
cp = MkCallAndPos(call, pos) and
359+
call.hasTargetCand(abs, f) and
360+
toCheckRanked(abs, f, pos, rnk) and
361+
Input::toCheck(abs, f, pos, constraint)
362+
}
363+
364+
pragma[nomagic]
365+
predicate potentialInstantiationOf(CallAndPos cp, TypeAbstraction abs, FunctionType constraint) {
366+
exists(Input::Call call, FunctionTypePosition pos, int rnk, Function f |
367+
potentialInstantiationOf0(cp, call, pos, rnk, f, abs, constraint)
368+
|
369+
rnk = 0
370+
or
371+
argsAreInstantiationsOfFromIndex(call, abs, f, rnk - 1)
372+
)
373+
}
374+
375+
predicate relevantTypeMention(FunctionType constraint) { Input::toCheck(_, _, _, constraint) }
376+
}
377+
378+
private module ArgIsInstantiationOfFromIndex =
379+
ArgIsInstantiationOf<CallAndPos, ArgIsInstantiationOfInput>;
380+
381+
pragma[nomagic]
382+
private predicate argsAreInstantiationsOfFromIndex(
383+
Input::Call call, ImplOrTraitItemNode i, Function f, int rnk
384+
) {
385+
exists(FunctionTypePosition pos |
386+
ArgIsInstantiationOfFromIndex::argIsInstantiationOf(MkCallAndPos(call, pos), i, _) and
387+
call.hasTargetCand(i, f) and
388+
toCheckRanked(i, f, pos, rnk)
389+
)
390+
}
391+
392+
/**
393+
* Holds if all arguments of `call` have types that are instantiations of the
394+
* types of the corresponding parameters of `f` inside `i`.
395+
*/
396+
pragma[nomagic]
397+
predicate argsAreInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
398+
exists(int rnk |
399+
argsAreInstantiationsOfFromIndex(call, i, f, rnk) and
400+
rnk = max(int r | toCheckRanked(i, f, _, r))
401+
)
402+
}
403+
}

rust/ql/test/library-tests/type-inference/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ multipleCallTargets
55
| dereference.rs:184:17:184:30 | ... .foo() |
66
| dereference.rs:186:17:186:25 | S.bar(...) |
77
| dereference.rs:187:17:187:29 | S.bar(...) |
8-
| main.rs:2383:9:2383:34 | ...::my_from2(...) |
9-
| main.rs:2384:9:2384:33 | ...::my_from2(...) |
10-
| main.rs:2385:9:2385:38 | ...::my_from2(...) |
118
| main.rs:2437:13:2437:31 | ...::from(...) |
129
| main.rs:2438:13:2438:31 | ...::from(...) |
1310
| main.rs:2439:13:2439:31 | ...::from(...) |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,9 @@ mod function_trait_bounds {
698698
T2::assoc(x) // $ target=assoc
699699
}
700700
fn call_trait_assoc_2<T1, T2: MyTrait<T1> + Copy>(x: T2) -> T1 {
701-
let y = MyTrait::assoc(x); // $ MISSING: target=assoc
702-
y; // $ MISSING: type=y:T1
703-
MyTrait::assoc(x) // $ MISSING: target=assoc
701+
let y = MyTrait::assoc(x); // $ target=assoc
702+
y; // $ type=y:T1
703+
MyTrait::assoc(x) // $ target=assoc
704704
}
705705

706706
// Type parameter with bound occurs nested within another type.
@@ -2380,9 +2380,9 @@ mod method_determined_by_argument_type {
23802380
let x = i64::my_from(73i64); // $ target=MyFrom<i64>::my_from
23812381
let y = i64::my_from(true); // $ target=MyFrom<bool>::my_from
23822382
let z: i64 = MyFrom::my_from(73i64); // $ target=MyFrom<i64>::my_from
2383-
i64::my_from2(73i64, 0i64); // $ target=MyFrom2<i64>::my_from2 $ SPURIOUS: target=MyFrom2<bool>::my_from2
2384-
i64::my_from2(true, 0i64); // $ target=MyFrom2<bool>::my_from2 $ SPURIOUS: target=MyFrom2<i64>::my_from2
2385-
MyFrom2::my_from2(73i64, 0i64); // $ target=MyFrom2<i64>::my_from2 $ SPURIOUS: target=MyFrom2<bool>::my_from2
2383+
i64::my_from2(73i64, 0i64); // $ target=MyFrom2<i64>::my_from2
2384+
i64::my_from2(true, 0i64); // $ target=MyFrom2<bool>::my_from2
2385+
MyFrom2::my_from2(73i64, 0i64); // $ target=MyFrom2<i64>::my_from2
23862386

23872387
i64::f1(73i64); // $ target=MySelfTrait<i64>::f1
23882388
i64::f2(73i64); // $ target=MySelfTrait<i64>::f2
@@ -2446,9 +2446,9 @@ mod loops {
24462446
String::from("bar"), // $ target=from
24472447
String::from("baz"), // $ target=from
24482448
];
2449-
for s in strings3 {} // $ MISSING: type=s:String
2449+
for s in strings3 {} // $ type=s:&T.String
24502450

2451-
let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ target=new $ MISSING: type=callables:[T;...].MyCallable; 3
2451+
let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ target=new $ type=callables:[T;...].MyCallable
24522452
for c // $ type=c:MyCallable
24532453
in callables
24542454
{
@@ -2502,10 +2502,10 @@ mod loops {
25022502
let mut map1 = std::collections::HashMap::new(); // $ target=new type=map1:K.i32 type=map1:V.Box $ MISSING: type=map1:Hashmap type1=map1:V.T.&T.str
25032503
map1.insert(1, Box::new("one")); // $ target=insert target=new
25042504
map1.insert(2, Box::new("two")); // $ target=insert target=new
2505-
for key in map1.keys() {} // $ target=keys MISSING: type=key:i32
2506-
for value in map1.values() {} // $ target=values MISSING: type=value:Box type=value:T.&T.str
2507-
for (key, value) in map1.iter() {} // $ target=iter MISSING: type=key:i32 type=value:Box type=value:T.&T.str
2508-
for (key, value) in &map1 {} // $ MISSING: type=key:i32 type=value:Box type=value:T.&T.str
2505+
for key in map1.keys() {} // $ target=keys type=key:&T.i32
2506+
for value in map1.values() {} // $ target=values type=value:&T.Box type=value:&T.T.&T.str
2507+
for (key, value) in map1.iter() {} // $ target=iter type=key:&T.i32 type=value:&T.Box type=value:&T.T.&T.str
2508+
for (key, value) in &map1 {} // $ type=key:&T.i32 type=value:&T.Box type=value:&T.T.&T.str
25092509

25102510
// while loops
25112511

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,10 @@ inferType
19921992
| main.rs:698:19:698:19 | x | | main.rs:695:31:695:52 | T2 |
19931993
| main.rs:700:55:700:55 | x | | main.rs:700:31:700:52 | T2 |
19941994
| main.rs:700:68:704:5 | { ... } | | main.rs:700:27:700:28 | T1 |
1995+
| main.rs:701:13:701:13 | y | | main.rs:700:27:700:28 | T1 |
1996+
| main.rs:701:17:701:33 | ...::assoc(...) | | main.rs:700:27:700:28 | T1 |
19951997
| main.rs:701:32:701:32 | x | | main.rs:700:31:700:52 | T2 |
1998+
| main.rs:702:9:702:9 | y | | main.rs:700:27:700:28 | T1 |
19961999
| main.rs:703:9:703:25 | ...::assoc(...) | | main.rs:700:27:700:28 | T1 |
19972000
| main.rs:703:24:703:24 | x | | main.rs:700:31:700:52 | T2 |
19982001
| main.rs:708:49:708:49 | x | | main.rs:656:5:659:5 | MyThing |

0 commit comments

Comments
 (0)