Skip to content

Commit 1d01250

Browse files
committed
Rust: Restrict type propagation into arguments
1 parent 7032f75 commit 1d01250

File tree

10 files changed

+200
-574
lines changed

10 files changed

+200
-574
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ newtype TType =
5151
TSliceType() or
5252
TNeverType() or
5353
TPtrType() or
54+
TContextType() or
5455
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
5556
TTypeParamTypeParameter(TypeParam t) or
5657
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
@@ -371,6 +372,26 @@ class PtrType extends Type, TPtrType {
371372
override Location getLocation() { result instanceof EmptyLocation }
372373
}
373374

375+
/**
376+
* A special pseudo type used to indicate that the actual type is to be inferred
377+
* from a context.
378+
*
379+
* For example, a call like `Default::default()` is assigned this type, which
380+
* means that the actual type is to be inferred from the context in which the call
381+
* occurs.
382+
*
383+
* Context types are not restricted to root types, for example in a call like
384+
* `Vec::new()` we assign this type at the type path corresponding to the type
385+
* parameter of `Vec`.
386+
*/
387+
class ContextType extends Type, TContextType {
388+
override TypeParameter getPositionalTypeParameter(int i) { none() }
389+
390+
override string toString() { result = "(context typed)" }
391+
392+
override Location getLocation() { result instanceof EmptyLocation }
393+
}
394+
374395
/** A type parameter. */
375396
abstract class TypeParameter extends Type {
376397
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 157 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,113 @@ private Type getCallExprTypeQualifier(CallExpr ce, TypePath path) {
885885
)
886886
}
887887

888+
/**
889+
* Provides functionality related to context-based typing of calls.
890+
*/
891+
private module ContextTyping {
892+
/**
893+
* Holds if the return type of the function `f` at path `path` is `tp`,
894+
* and `tp` does not appear in the type of any parameter of `f`.
895+
*
896+
* In this case, the context in which `f` is called may be needed to infer
897+
* the instantiation of `tp`.
898+
*/
899+
pragma[nomagic]
900+
private predicate assocFunctionReturnContextTypedAt(
901+
Function f, FunctionPosition pos, TypePath path, TypeParameter tp
902+
) {
903+
exists(ImplOrTraitItemNode i |
904+
pos.isReturn() and
905+
assocFunctionTypeAt(f, i, pos, path, tp) and
906+
not exists(FunctionPosition nonResPos |
907+
not nonResPos.isReturn() and
908+
assocFunctionTypeAt(f, i, nonResPos, _, tp)
909+
)
910+
)
911+
}
912+
913+
/**
914+
* A call where the type of the result may have to be inferred from the
915+
* context in which the call appears, for example a call like
916+
* `Default::default()`.
917+
*/
918+
abstract class ContextTypedCallCand extends AstNode {
919+
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
920+
921+
private predicate hasTypeArgument(TypeArgumentPosition apos) {
922+
exists(this.getTypeArgument(apos, _))
923+
}
924+
925+
/**
926+
* Holds if `this` call resolves to `target` and the type at `pos` and `path`
927+
* may have to be inferred from the context.
928+
*/
929+
bindingset[this, target]
930+
predicate isContextTypedAt(Function target, TypePath path, FunctionPosition pos) {
931+
exists(TypeParameter tp |
932+
assocFunctionReturnContextTypedAt(target, pos, path, tp) and
933+
// check that no explicit type arguments have been supplied for `tp`
934+
not exists(TypeArgumentPosition tapos | this.hasTypeArgument(tapos) |
935+
exists(int i |
936+
i = tapos.asMethodTypeArgumentPosition() and
937+
tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
938+
)
939+
or
940+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
941+
) and
942+
not (
943+
tp instanceof TSelfTypeParameter and
944+
exists(getCallExprTypeQualifier(this, _))
945+
)
946+
)
947+
}
948+
}
949+
950+
pragma[nomagic]
951+
private predicate isContextTyped(AstNode n, TypePath path) { inferType(n, path) = TContextType() }
952+
953+
pragma[nomagic]
954+
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }
955+
956+
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
957+
958+
/**
959+
* Given a predicate `inferCallType` for inferring the type of a call at a given
960+
* position, this module exposes the predicate `check`, which wraps the input
961+
* predicate and checks that types are only propagated into arguments when they
962+
* are context-typed.
963+
*/
964+
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
965+
pragma[nomagic]
966+
private Type inferCallTypeFromContextCand(
967+
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
968+
) {
969+
result = inferCallType(n, pos, path) and
970+
not pos.isReturn() and
971+
isContextTyped(n) and
972+
prefix = path
973+
or
974+
exists(TypePath mid |
975+
result = inferCallTypeFromContextCand(n, pos, path, mid) and
976+
mid.isSnoc(prefix, _)
977+
)
978+
}
979+
980+
pragma[nomagic]
981+
Type check(AstNode n, TypePath path) {
982+
exists(FunctionPosition pos |
983+
result = inferCallType(n, pos, path) and
984+
pos.isReturn()
985+
or
986+
exists(TypePath prefix |
987+
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
988+
isContextTyped(n, prefix)
989+
)
990+
)
991+
}
992+
}
993+
}
994+
888995
/**
889996
* Holds if function `f` with the name `name` and the arity `arity` exists in
890997
* `i`, and the type at position `pos` is `t`.
@@ -1890,14 +1997,14 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
18901997

18911998
final private class MethodCallFinal = MethodResolution::MethodCall;
18921999

1893-
class Access extends MethodCallFinal {
2000+
class Access extends MethodCallFinal, ContextTyping::ContextTypedCallCand {
18942001
Access() {
18952002
// handled in the `OperationMatchingInput` module
18962003
not this instanceof Operation
18972004
}
18982005

18992006
pragma[nomagic]
1900-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2007+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
19012008
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
19022009
arg =
19032010
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
@@ -1961,7 +2068,12 @@ private Type inferMethodCallType0(
19612068
) {
19622069
exists(TypePath path0 |
19632070
n = a.getNodeAt(apos) and
1964-
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2071+
(
2072+
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
2073+
or
2074+
a.isContextTypedAt(a.getTarget(derefChainBorrow), path0, apos) and
2075+
result = TContextType()
2076+
)
19652077
|
19662078
if
19672079
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
@@ -1973,16 +2085,11 @@ private Type inferMethodCallType0(
19732085
)
19742086
}
19752087

1976-
/**
1977-
* Gets the type of `n` at `path`, where `n` is either a method call or an
1978-
* argument/receiver of a method call.
1979-
*/
19802088
pragma[nomagic]
1981-
private Type inferMethodCallType(AstNode n, TypePath path) {
1982-
exists(
1983-
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
1984-
string derefChainBorrow, TypePath path0
1985-
|
2089+
private Type inferMethodCallType1(
2090+
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
2091+
) {
2092+
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
19862093
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
19872094
|
19882095
(
@@ -2004,6 +2111,15 @@ private Type inferMethodCallType(AstNode n, TypePath path) {
20042111
)
20052112
}
20062113

2114+
/**
2115+
* Gets the type of `n` at `path`, where `n` is either a method call or an
2116+
* argument/receiver of a method call.
2117+
*/
2118+
pragma[nomagic]
2119+
private Type inferMethodCallType(AstNode n, TypePath path) {
2120+
result = ContextTyping::CheckContextTyping<inferMethodCallType1/3>::check(n, path)
2121+
}
2122+
20072123
/**
20082124
* Provides logic for resolving calls to non-method items. This includes
20092125
* "calls" to tuple variants and tuple structs.
@@ -2171,6 +2287,12 @@ private module NonMethodResolution {
21712287
or
21722288
result = this.resolveCallTargetRec()
21732289
}
2290+
2291+
pragma[nomagic]
2292+
Function resolveTraitFunction() {
2293+
this.(Call).hasTrait() and
2294+
result = this.getPathResolutionResolved()
2295+
}
21742296
}
21752297

21762298
private newtype TCallAndBlanketPos =
@@ -2405,9 +2527,9 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24052527
}
24062528
}
24072529

2408-
class Access extends NonMethodResolution::NonMethodCall {
2530+
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
24092531
pragma[nomagic]
2410-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
2532+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
24112533
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
24122534
}
24132535

@@ -2428,13 +2550,22 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
24282550
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
24292551

24302552
pragma[nomagic]
2431-
private Type inferNonMethodCallType(AstNode n, TypePath path) {
2432-
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
2433-
n = a.getNodeAt(apos) and
2553+
private Type inferNonMethodCallType0(
2554+
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
2555+
) {
2556+
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
24342557
result = NonMethodCallMatching::inferAccessType(a, apos, path)
2558+
or
2559+
a.isContextTypedAt([a.resolveCallTarget().(Function), a.resolveTraitFunction()], path, apos) and
2560+
result = TContextType()
24352561
)
24362562
}
24372563

2564+
pragma[nomagic]
2565+
private Type inferNonMethodCallType(AstNode n, TypePath path) {
2566+
result = ContextTyping::CheckContextTyping<inferNonMethodCallType0/3>::check(n, path)
2567+
}
2568+
24382569
/**
24392570
* A matching configuration for resolving types of operations like `a + b`.
24402571
*/
@@ -2507,13 +2638,20 @@ private module OperationMatchingInput implements MatchingInputSig {
25072638
private module OperationMatching = Matching<OperationMatchingInput>;
25082639

25092640
pragma[nomagic]
2510-
private Type inferOperationType(AstNode n, TypePath path) {
2511-
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
2641+
private Type inferOperationType0(
2642+
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
2643+
) {
2644+
exists(OperationMatchingInput::Access a |
25122645
n = a.getNodeAt(apos) and
25132646
result = OperationMatching::inferAccessType(a, apos, path)
25142647
)
25152648
}
25162649

2650+
pragma[nomagic]
2651+
private Type inferOperationType(AstNode n, TypePath path) {
2652+
result = ContextTyping::CheckContextTyping<inferOperationType0/3>::check(n, path)
2653+
}
2654+
25172655
pragma[nomagic]
25182656
private Type getFieldExprLookupType(FieldExpr fe, string name) {
25192657
exists(TypePath path |
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
multipleCallTargets
2+
| test.rs:24:24:24:34 | row.take(...) |
3+
| test.rs:111:24:111:34 | row.take(...) |
14
multiplePathResolutions
25
| test.rs:10:28:10:65 | Result::<...> |
36
| test.rs:97:40:97:49 | Result::<...> |
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| test.rs:288:7:288:36 | ... .as_str() |

0 commit comments

Comments
 (0)