Skip to content

Commit 066f13d

Browse files
committed
Rust: Improve type inference for closures
1 parent ff26b57 commit 066f13d

File tree

5 files changed

+553
-155
lines changed

5 files changed

+553
-155
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 104 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
467467
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
468468
}
469469

470+
// Due to "binding modes" the type of the pattern is not necessarily the
471+
// same as the type of the initializer. The pattern being an identifier
472+
// pattern is sufficient to ensure that this is not the case.
473+
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
474+
let.getPat() = lhs and
475+
let.getInitializer() = rhs
476+
}
477+
470478
/** Module for inferring certain type information. */
471479
module CertainTypeInference {
472480
pragma[nomagic]
@@ -544,11 +552,7 @@ module CertainTypeInference {
544552
// is not a certain type equality.
545553
exists(LetStmt let |
546554
not let.hasTypeRepr() and
547-
// Due to "binding modes" the type of the pattern is not necessarily the
548-
// same as the type of the initializer. The pattern being an identifier
549-
// pattern is sufficient to ensure that this is not the case.
550-
let.getPat().(IdentPat) = n1 and
551-
let.getInitializer() = n2
555+
identLetStmt(let, n1, n2)
552556
)
553557
or
554558
exists(LetExpr let |
@@ -572,6 +576,25 @@ module CertainTypeInference {
572576
)
573577
else prefix2.isEmpty()
574578
)
579+
or
580+
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
581+
n1 = dce.getArgList() and
582+
tt.getArity() = dce.getNumberOfSyntacticArguments() and
583+
n2 = dce.getSyntacticPositionalArgument(i) and
584+
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
585+
prefix2.isEmpty()
586+
)
587+
or
588+
exists(ClosureExpr ce, int index |
589+
n1 = ce and
590+
n2 = ce.getParam(index).getPat() and
591+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
592+
prefix2.isEmpty()
593+
)
594+
or
595+
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
596+
prefix1 = closureReturnPath() and
597+
prefix2.isEmpty()
575598
}
576599

577600
pragma[nomagic]
@@ -834,17 +857,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
834857
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
835858
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
836859
prefix2.isEmpty()
837-
or
838-
exists(ClosureExpr ce, int index |
839-
n1 = ce and
840-
n2 = ce.getParam(index).getPat() and
841-
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
842-
prefix2.isEmpty()
843-
)
844-
or
845-
n1.(ClosureExpr).getClosureBody() = n2 and
846-
prefix1 = closureReturnPath() and
847-
prefix2.isEmpty()
848860
}
849861

850862
/**
@@ -887,6 +899,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
887899
)
888900
}
889901

902+
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
903+
inferType(n, path) = TUnknownType() and
904+
// Normally, these are coercion sites, but in case a type is unknown we
905+
// allow for type information to flow from the type annotation.
906+
exists(TypeMention tm | result = tm.getTypeAt(path) |
907+
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
908+
or
909+
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
910+
or
911+
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
912+
)
913+
}
914+
890915
/**
891916
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
892917
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
@@ -1547,6 +1572,8 @@ private module AssocFunctionResolution {
15471572
* 3. `AssocFunctionCallCallExpr`: a qualified function call, `Q::f(x)`; or
15481573
* 4. `AssocFunctionCallOperation`: an operation expression, `x + y`, which is syntactic sugar
15491574
* for `Add::add(x, y)`.
1575+
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
1576+
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
15501577
*
15511578
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
15521579
*
@@ -1563,7 +1590,7 @@ private module AssocFunctionResolution {
15631590
pragma[nomagic]
15641591
abstract predicate hasNameAndArity(string name, int arity);
15651592

1566-
abstract Expr getNonReturnNodeAt(FunctionPosition pos);
1593+
abstract AstNode getNonReturnNodeAt(FunctionPosition pos);
15671594

15681595
AstNode getNodeAt(FunctionPosition pos) {
15691596
result = this.getNonReturnNodeAt(pos)
@@ -2097,7 +2124,7 @@ private module AssocFunctionResolution {
20972124
}
20982125
}
20992126

2100-
private class AssocFunctionCallMethodCallExpr extends AssocFunctionCall instanceof MethodCallExpr {
2127+
private class MethodCallExprAssocFunctionCall extends AssocFunctionCall instanceof MethodCallExpr {
21012128
override predicate hasNameAndArity(string name, int arity) {
21022129
name = super.getIdentifier().getText() and
21032130
arity = super.getNumberOfSyntacticArguments()
@@ -2117,7 +2144,7 @@ private module AssocFunctionResolution {
21172144
override Trait getTrait() { none() }
21182145
}
21192146

2120-
private class AssocFunctionCallIndexExpr extends AssocFunctionCall, IndexExpr {
2147+
private class IndexExprAssocFunctionCall extends AssocFunctionCall, IndexExpr {
21212148
private predicate isInMutableContext() {
21222149
// todo: does not handle all cases yet
21232150
VariableImpl::assignmentOperationDescendant(_, this)
@@ -2147,8 +2174,8 @@ private module AssocFunctionResolution {
21472174
}
21482175
}
21492176

2150-
private class AssocFunctionCallCallExpr extends AssocFunctionCall, CallExpr {
2151-
AssocFunctionCallCallExpr() {
2177+
private class CallExprAssocFunctionCall extends AssocFunctionCall, CallExpr {
2178+
CallExprAssocFunctionCall() {
21522179
exists(getCallExprPathQualifier(this)) and
21532180
// even if a target cannot be resolved by path resolution, it may still
21542181
// be possible to resolve a blanket implementation (so not `forex`)
@@ -2180,7 +2207,7 @@ private module AssocFunctionResolution {
21802207
override Trait getTrait() { result = getCallExprTraitQualifier(this) }
21812208
}
21822209

2183-
final class AssocFunctionCallOperation extends AssocFunctionCall, Operation {
2210+
final class OperationAssocFunctionCall extends AssocFunctionCall, Operation {
21842211
override predicate hasNameAndArity(string name, int arity) {
21852212
this.isOverloaded(_, name, _) and
21862213
arity = this.getNumberOfOperands()
@@ -2238,6 +2265,29 @@ private module AssocFunctionResolution {
22382265
override Trait getTrait() { this.isOverloaded(result, _, _) }
22392266
}
22402267

2268+
private class DynamicAssocFunctionCall extends AssocFunctionCall instanceof CallExprImpl::DynamicCallExpr
2269+
{
2270+
pragma[nomagic]
2271+
override predicate hasNameAndArity(string name, int arity) {
2272+
name = "call_once" and // todo: handle call_mut and call
2273+
arity = 2 // args are passed in a tuple
2274+
}
2275+
2276+
override predicate hasReceiver() { any() }
2277+
2278+
override AstNode getNonReturnNodeAt(FunctionPosition pos) {
2279+
pos.asPosition() = 0 and
2280+
result = super.getFunction()
2281+
or
2282+
pos.asPosition() = 1 and
2283+
result = super.getArgList()
2284+
}
2285+
2286+
override predicate supportsAutoDerefAndBorrow() { any() }
2287+
2288+
override Trait getTrait() { result instanceof AnyFnTrait }
2289+
}
2290+
22412291
pragma[nomagic]
22422292
private AssocFunctionDeclaration getAssocFunctionSuccessor(
22432293
ImplOrTraitItemNode i, string name, int arity
@@ -3163,7 +3213,7 @@ private module OperationMatchingInput implements MatchingInputSig {
31633213
}
31643214
}
31653215

3166-
class Access extends AssocFunctionResolution::AssocFunctionCallOperation {
3216+
class Access extends AssocFunctionResolution::OperationAssocFunctionCall {
31673217
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
31683218

31693219
pragma[nomagic]
@@ -3725,14 +3775,6 @@ private module InvokedClosureSatisfiesTypeInput implements SatisfiesTypeInputSig
37253775
}
37263776
}
37273777

3728-
private module InvokedClosureSatisfiesType =
3729-
SatisfiesType<InvokedClosureExpr, InvokedClosureSatisfiesTypeInput>;
3730-
3731-
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
3732-
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3733-
InvokedClosureSatisfiesType::satisfiesConstraintType(ce, _, path, result)
3734-
}
3735-
37363778
/**
37373779
* Gets the root type of a closure.
37383780
*
@@ -3759,73 +3801,39 @@ private TypePath closureParameterPath(int arity, int index) {
37593801
TypePath::singleton(getTupleTypeParameter(arity, index)))
37603802
}
37613803

3762-
/** Gets the path to the return type of the `FnOnce` trait. */
3763-
private TypePath fnReturnPath() {
3764-
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3765-
}
3766-
3767-
/**
3768-
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
3769-
* and index `index`.
3770-
*/
37713804
pragma[nomagic]
3772-
private TypePath fnParameterPath(int arity, int index) {
3773-
result =
3774-
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
3775-
TypePath::singleton(getTupleTypeParameter(arity, index)))
3776-
}
3777-
3778-
pragma[nomagic]
3779-
private Type inferDynamicCallExprType(Expr n, TypePath path) {
3780-
exists(InvokedClosureExpr ce |
3781-
// Propagate the function's return type to the call expression
3782-
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
3783-
n = ce.getCall() and
3784-
path = path0.stripPrefix(fnReturnPath())
3805+
private Type inferClosureExprType(AstNode n, TypePath path) {
3806+
exists(ClosureExpr ce |
3807+
n = ce and
3808+
(
3809+
path.isEmpty() and
3810+
result = closureRootType()
3811+
or
3812+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3813+
result.(TupleType).getArity() = ce.getNumberOfParams()
37853814
or
3786-
// Propagate the function's parameter type to the arguments
3787-
exists(int index |
3788-
n = ce.getCall().getSyntacticPositionalArgument(index) and
3789-
path =
3790-
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
3815+
exists(TypePath path0 |
3816+
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and
3817+
path = closureReturnPath().append(path0)
37913818
)
37923819
)
37933820
or
3794-
// _If_ the invoked expression has the type of a closure, then we propagate
3795-
// the surrounding types into the closure.
3796-
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
3797-
// Propagate the type of arguments to the parameter types of closure
3798-
exists(int index, ArgList args |
3799-
n = ce and
3800-
args = ce.getCall().getArgList() and
3801-
arity = args.getNumberOfArgs() and
3802-
result = inferType(args.getArg(index), path0) and
3803-
path = closureParameterPath(arity, index).append(path0)
3804-
)
3805-
or
3806-
// Propagate the type of the call expression to the return type of the closure
3807-
n = ce and
3808-
arity = ce.getCall().getArgList().getNumberOfArgs() and
3809-
result = inferType(ce.getCall(), path0) and
3810-
path = closureReturnPath().append(path0)
3821+
exists(Param p |
3822+
p = ce.getAParam() and
3823+
not p.hasTypeRepr() and
3824+
n = p.getPat() and
3825+
result = TUnknownType() and
3826+
path.isEmpty()
38113827
)
38123828
)
38133829
}
38143830

38153831
pragma[nomagic]
3816-
private Type inferClosureExprType(AstNode n, TypePath path) {
3817-
exists(ClosureExpr ce |
3818-
n = ce and
3819-
path.isEmpty() and
3820-
result = closureRootType()
3821-
or
3822-
n = ce and
3823-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
3824-
result.(TupleType).getArity() = ce.getNumberOfParams()
3825-
or
3826-
// Propagate return type annotation to body
3827-
n = ce.getClosureBody() and
3828-
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path)
3832+
private TupleType inferArgList(ArgList args, TypePath path) {
3833+
exists(CallExprImpl::DynamicCallExpr dce |
3834+
args = dce.getArgList() and
3835+
result.getArity() = dce.getNumberOfSyntacticArguments() and
3836+
path.isEmpty()
38293837
)
38303838
}
38313839

@@ -3873,7 +3881,8 @@ private module Cached {
38733881
or
38743882
i instanceof ImplItemNode and dispatch = false
38753883
|
3876-
result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _)
3884+
result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and
3885+
not call instanceof CallExprImpl::DynamicCallExpr // todo
38773886
)
38783887
}
38793888

@@ -3980,11 +3989,13 @@ private module Cached {
39803989
or
39813990
result = inferForLoopExprType(n, path)
39823991
or
3983-
result = inferDynamicCallExprType(n, path)
3984-
or
39853992
result = inferClosureExprType(n, path)
39863993
or
3994+
result = inferArgList(n, path)
3995+
or
39873996
result = inferDeconstructionPatType(n, path)
3997+
or
3998+
result = inferUnknownTypeFromAnnotation(n, path)
39883999
)
39894000
}
39904001
}
@@ -4001,8 +4012,8 @@ private module Debug {
40014012
Locatable getRelevantLocatable() {
40024013
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
40034014
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
4004-
filepath.matches("%/main.rs") and
4005-
startline = 103
4015+
filepath.matches("%/regressions.rs") and
4016+
startline = 24
40064017
)
40074018
}
40084019

0 commit comments

Comments
 (0)