Skip to content

Commit bd8665a

Browse files
committed
wip
1 parent fc9ed22 commit bd8665a

File tree

5 files changed

+140
-81
lines changed

5 files changed

+140
-81
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,7 @@ private module Debug {
21372137
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
21382138
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
21392139
filepath.matches("%/main.rs") and
2140-
startline = 52
2140+
startline = 2909
21412141
)
21422142
}
21432143

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 124 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,11 @@ pragma[nomagic]
263263
private TypeMention getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos) {
264264
exists(Path p, int i |
265265
p = CallExprImpl::getFunctionPath(ce) and
266-
result = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
267266
apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i))
267+
|
268+
result = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i))
269+
or
270+
result = p.(NonAliasPathTypeMention).getPathPositionalTypeArgument(pragma[only_bind_into](i))
268271
)
269272
}
270273

@@ -748,6 +751,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
748751
/**
749752
* A matching configuration for resolving types of struct expressions
750753
* like `Foo { bar = baz }`.
754+
*
755+
* This also includes nullary struct expressions like `None`.
751756
*/
752757
private module StructExprMatchingInput implements MatchingInputSig {
753758
private newtype TPos =
@@ -830,26 +835,86 @@ private module StructExprMatchingInput implements MatchingInputSig {
830835

831836
class AccessPosition = DeclarationPosition;
832837

833-
class Access extends StructExpr {
838+
abstract class Access extends AstNode {
839+
pragma[nomagic]
840+
abstract AstNode getNodeAt(AccessPosition apos);
841+
842+
pragma[nomagic]
843+
Type getInferredType(AccessPosition apos, TypePath path) {
844+
result = inferType(this.getNodeAt(apos), path)
845+
}
846+
847+
pragma[nomagic]
848+
abstract Path getStructPath();
849+
850+
pragma[nomagic]
851+
Declaration getTarget() { result = resolvePath(this.getStructPath()) }
852+
853+
pragma[nomagic]
834854
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
855+
exists(TypeMention tm, TypePath path0 |
856+
tm = this.getStructPath() and
857+
result = tm.resolveTypeAt(path0) and
858+
path0.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path)
859+
)
860+
}
861+
862+
/**
863+
* Holds if the return type of this call at `path` may have to be inferred
864+
* from the context.
865+
*/
866+
pragma[nomagic]
867+
predicate isContextTypedAt(DeclarationPosition pos, TypePath path) {
868+
// Struct declarations, such as `Foo::Bar{field = ...}`, may also be context typed
869+
exists(Declaration td, TypeParameter tp |
870+
td = this.getTarget() and
871+
pos.isStructPos() and
872+
tp = td.getDeclaredType(pos, path) and
873+
not exists(DeclarationPosition paramDpos |
874+
not paramDpos.isStructPos() and
875+
tp = td.getDeclaredType(paramDpos, _)
876+
) and
877+
// check that no explicit type arguments have been supplied for `tp`
878+
not exists(TypeArgumentPosition tapos |
879+
exists(this.getTypeArgument(tapos, _)) and
880+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
881+
)
882+
)
883+
}
884+
}
885+
886+
private class StructExprAccess extends Access, StructExpr {
887+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
888+
result = super.getTypeArgument(apos, path)
889+
or
835890
exists(TypePath suffix |
836891
suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and
837892
result = CertainTypeInference::inferCertainType(this, suffix)
838893
)
839894
}
840895

841-
AstNode getNodeAt(AccessPosition apos) {
896+
override AstNode getNodeAt(AccessPosition apos) {
842897
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
843898
or
844899
result = this and
845900
apos.isStructPos()
846901
}
847902

848-
Type getInferredType(AccessPosition apos, TypePath path) {
849-
result = inferType(this.getNodeAt(apos), path)
903+
override Path getStructPath() { result = this.getPath() }
904+
}
905+
906+
/**
907+
* A potential nullary struct/variant construction such as `None`.
908+
*/
909+
private class PathExprAccess extends Access, PathExpr {
910+
PathExprAccess() { not exists(CallExpr ce | this = ce.getFunction()) }
911+
912+
override AstNode getNodeAt(AccessPosition apos) {
913+
result = this and
914+
apos.isStructPos()
850915
}
851916

852-
Declaration getTarget() { result = resolvePath(this.getPath()) }
917+
override Path getStructPath() { result = this.getPath() }
853918
}
854919

855920
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
@@ -859,36 +924,32 @@ private module StructExprMatchingInput implements MatchingInputSig {
859924

860925
private module StructExprMatching = Matching<StructExprMatchingInput>;
861926

862-
/**
863-
* Gets the type of `n` at `path`, where `n` is either a struct expression or
864-
* a field expression of a struct expression.
865-
*/
866927
pragma[nomagic]
867-
private Type inferStructExprType(AstNode n, TypePath path) {
928+
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
868929
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
869930
n = a.getNodeAt(apos) and
931+
if apos.isStructPos() then isReturn = true else isReturn = false
932+
|
870933
result = StructExprMatching::inferAccessType(a, apos, path)
934+
or
935+
a.isContextTypedAt(apos, path) and
936+
result = TContextType()
871937
)
872938
}
873939

940+
/**
941+
* Gets the type of `n` at `path`, where `n` is either a struct expression or
942+
* a field expression of a struct expression.
943+
*/
944+
private predicate inferStructExprType =
945+
ContextTyping::CheckContextTyping<inferStructExprType0/3>::check/2;
946+
874947
pragma[nomagic]
875948
private Type inferTupleRootType(AstNode n) {
876949
// `typeEquality` handles the non-root cases
877950
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
878951
}
879952

880-
pragma[nomagic]
881-
private Type inferPathExprType(PathExpr pe, TypePath path) {
882-
// nullary struct/variant constructors
883-
not exists(CallExpr ce | pe = ce.getFunction()) and
884-
path.isEmpty() and
885-
exists(ItemNode i | i = resolvePath(pe.getPath()) |
886-
result = TEnum(i.(Variant).getEnum())
887-
or
888-
result = TStruct(i)
889-
)
890-
}
891-
892953
pragma[nomagic]
893954
private Path getCallExprPathQualifier(CallExpr ce) {
894955
result = CallExprImpl::getFunctionPath(ce).getQualifier()
@@ -982,7 +1043,7 @@ private module ContextTyping {
9821043
pragma[nomagic]
9831044
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }
9841045

985-
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
1046+
signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
9861047

9871048
/**
9881049
* Given a predicate `inferCallType` for inferring the type of a call at a given
@@ -992,30 +1053,24 @@ private module ContextTyping {
9921053
*/
9931054
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
9941055
pragma[nomagic]
995-
private Type inferCallTypeFromContextCand(
996-
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
997-
) {
998-
result = inferCallType(n, pos, path) and
999-
not pos.isReturn() and
1056+
private Type inferCallTypeFromContextCand(AstNode n, TypePath path, TypePath prefix) {
1057+
result = inferCallType(n, false, path) and
10001058
isContextTyped(n) and
10011059
prefix = path
10021060
or
10031061
exists(TypePath mid |
1004-
result = inferCallTypeFromContextCand(n, pos, path, mid) and
1062+
result = inferCallTypeFromContextCand(n, path, mid) and
10051063
mid.isSnoc(prefix, _)
10061064
)
10071065
}
10081066

10091067
pragma[nomagic]
10101068
Type check(AstNode n, TypePath path) {
1011-
exists(FunctionPosition pos |
1012-
result = inferCallType(n, pos, path) and
1013-
pos.isReturn()
1014-
or
1015-
exists(TypePath prefix |
1016-
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
1017-
isContextTyped(n, prefix)
1018-
)
1069+
result = inferCallType(n, true, path)
1070+
or
1071+
exists(TypePath prefix |
1072+
result = inferCallTypeFromContextCand(n, path, prefix) and
1073+
isContextTyped(n, prefix)
10191074
)
10201075
}
10211076
}
@@ -2131,11 +2186,13 @@ private Type inferMethodCallType0(
21312186
}
21322187

21332188
pragma[nomagic]
2134-
private Type inferMethodCallType1(
2135-
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
2136-
) {
2137-
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
2138-
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
2189+
private Type inferMethodCallType1(AstNode n, boolean isReturn, TypePath path) {
2190+
exists(
2191+
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
2192+
string derefChainBorrow, TypePath path0
2193+
|
2194+
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0) and
2195+
if apos.isReturn() then isReturn = true else isReturn = false
21392196
|
21402197
(
21412198
not apos.isSelf()
@@ -2460,6 +2517,9 @@ private module NonMethodResolution {
24602517
/**
24612518
* A matching configuration for resolving types of calls like
24622519
* `foo::bar(baz)` where the target is not a method.
2520+
*
2521+
* This also includes "calls" to tuple variants and tuple structs such
2522+
* as `Result::Ok(42)`.
24632523
*/
24642524
private module NonMethodCallMatchingInput implements MatchingInputSig {
24652525
import FunctionPositionMatchingInput
@@ -2581,6 +2641,12 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25812641
pragma[nomagic]
25822642
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
25832643
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2644+
// todo: enum variants like `Self(...)`
2645+
}
2646+
2647+
pragma[nomagic]
2648+
AstNode getNodeAt(FunctionPosition pos) {
2649+
result = NonMethodResolution::NonMethodCall.super.getNodeAt(pos)
25842650
}
25852651

25862652
pragma[nomagic]
@@ -2591,26 +2657,23 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25912657
result = inferType(this.getNodeAt(apos), path)
25922658
}
25932659

2660+
pragma[nomagic]
25942661
Declaration getTarget() {
2595-
result = this.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2662+
result = super.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
25962663
}
25972664

2598-
/**
2599-
* Holds if the return type of this call at `path` may have to be inferred
2600-
* from the context.
2601-
*/
26022665
pragma[nomagic]
26032666
predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
26042667
exists(ImplOrTraitItemNode i |
26052668
this.isContextTypedAt(i,
26062669
[
2607-
this.resolveCallTargetViaPathResolution().(NonMethodFunction),
2608-
this.resolveCallTargetViaTypeInference(i),
2609-
this.resolveTraitFunctionViaPathResolution(i)
2670+
super.resolveCallTargetViaPathResolution().(NonMethodFunction),
2671+
super.resolveCallTargetViaTypeInference(i),
2672+
super.resolveTraitFunctionViaPathResolution(i)
26102673
], pos, path)
26112674
)
26122675
or
2613-
// Tuple declarations, such as `None`, may also be context typed
2676+
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
26142677
exists(TupleDeclaration td, TypeParameter tp |
26152678
td = this.resolveCallTargetViaPathResolution() and
26162679
pos.isReturn() and
@@ -2629,10 +2692,11 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
26292692
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
26302693

26312694
pragma[nomagic]
2632-
private Type inferNonMethodCallType0(
2633-
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
2634-
) {
2635-
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
2695+
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
2696+
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
2697+
n = a.getNodeAt(apos) and
2698+
if apos.isReturn() then isReturn = true else isReturn = false
2699+
|
26362700
result = NonMethodCallMatching::inferAccessType(a, apos, path)
26372701
or
26382702
a.isContextTypedAt(apos, path) and
@@ -2715,12 +2779,11 @@ private module OperationMatchingInput implements MatchingInputSig {
27152779
private module OperationMatching = Matching<OperationMatchingInput>;
27162780

27172781
pragma[nomagic]
2718-
private Type inferOperationType0(
2719-
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
2720-
) {
2721-
exists(OperationMatchingInput::Access a |
2782+
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
2783+
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
27222784
n = a.getNodeAt(apos) and
2723-
result = OperationMatching::inferAccessType(a, apos, path)
2785+
result = OperationMatching::inferAccessType(a, apos, path) and
2786+
if apos.isReturn() then isReturn = true else isReturn = false
27242787
)
27252788
}
27262789

@@ -3488,8 +3551,6 @@ private module Cached {
34883551
or
34893552
result = inferStructExprType(n, path)
34903553
or
3491-
result = inferPathExprType(n, path)
3492-
or
34933554
result = inferMethodCallType(n, path)
34943555
or
34953556
result = inferNonMethodCallType(n, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
147147
* Gets the positional type argument at index `i` that occurs in this path, if
148148
* any.
149149
*/
150-
private TypeMention getPathPositionalTypeArgument(int i) {
150+
TypeMention getPathPositionalTypeArgument(int i) {
151151
result = this.getSegment().getGenericArgList().getTypeArg(i)
152152
or
153153
// `Option::<i32>::Some` is valid in addition to `Option::Some::<i32>`

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2908,12 +2908,12 @@ mod context_typed {
29082908
pub fn f() {
29092909
let x = None; // $ type=x:T.i32
29102910
let x: Option<i32> = x;
2911-
let x = Option::<i32>::None; // $ MISSING: type=x:T.i32
2912-
let x = Option::None::<i32>; // $ MISSING: type=x:T.i32
2911+
let x = Option::<i32>::None; // $ type=x:T.i32
2912+
let x = Option::None::<i32>; // $ type=x:T.i32
29132913

29142914
fn pin_option<T>(opt: Option<T>, x: T) {}
29152915

2916-
let x = None; // $ MISSING: type=x:T.i32
2916+
let x = None; // $ type=x:T.i32
29172917
pin_option(x, 0); // $ target=pin_option
29182918

29192919
enum MyEither<T1, T2> {
@@ -2932,7 +2932,7 @@ mod context_typed {
29322932
fn pin_my_either<T>(e: MyEither<T, String>, x: T) {}
29332933

29342934
#[rustfmt::skip]
2935-
let x = MyEither::B { // $ type=x:T2.String $ MISSING: type=x:T1.i32
2935+
let x = MyEither::B { // $ type=x:T1.i32 type=x:T2.String
29362936
right: String::new(), // $ target=new
29372937
};
29382938
pin_my_either(x, 0); // $ target=pin_my_either

0 commit comments

Comments
 (0)