diff --git a/rust/ql/lib/codeql/rust/internal/CachedStages.qll b/rust/ql/lib/codeql/rust/internal/CachedStages.qll index a92770ed2384..f76006ddcbc7 100644 --- a/rust/ql/lib/codeql/rust/internal/CachedStages.qll +++ b/rust/ql/lib/codeql/rust/internal/CachedStages.qll @@ -147,9 +147,9 @@ module Stages { predicate backref() { 1 = 1 or - exists(Type t) + (exists(Type t) implies any()) or - exists(inferType(_)) + (exists(inferType(_)) implies any()) } } diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 423ad21ae4ac..1285256c1fb5 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -10,11 +10,11 @@ private import TypeAbstraction as TA private import Type as T private import TypeMention private import codeql.rust.internal.typeinference.DerefChain +private import codeql.rust.internal.CachedStages private import FunctionType private import FunctionOverloading as FunctionOverloading private import BlanketImplementation as BlanketImplementation private import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl -private import codeql.rust.internal.CachedStages private import codeql.typeinference.internal.TypeInference private import codeql.rust.frameworks.stdlib.Stdlib private import codeql.rust.frameworks.stdlib.Builtins as Builtins @@ -37,10 +37,7 @@ private module Input1 implements InputSig1 { class Type = T::Type; - predicate isPseudoType(Type t) { - t instanceof UnknownType or - t instanceof NeverType - } + class UnknownType = T::UnknownType; class TypeParameter = T::TypeParameter; @@ -276,17 +273,409 @@ private module M2 = Make2; import M2 +private module Input3 implements InputSig3 { + private import rust as Rust + + predicate cachedStageRevRef() { + Stages::TypeInferenceStage::ref() + or + (implicitDerefChainBorrow(_, _, _) implies any()) + or + (exists(resolveCallTarget(_, _)) implies any()) + or + (exists(resolveStructFieldExpr(_, _)) implies any()) + or + (exists(resolveTupleFieldExpr(_, _)) implies any()) + } + + predicate inferType = M3::inferType/2; + + class BoolType extends DataType { + BoolType() { this.getTypeItem() instanceof Builtins::Bool } + } + + class AstNode = Rust::AstNode; + + TypeMention getTypeAnnotation(AstNode n) { + exists(LetStmt let | + n = let.getPat() and + result = let.getTypeRepr() + ) + or + result = n.(SelfParam).getTypeRepr() + or + exists(Param p | + n = p.getPat() and + result = p.getTypeRepr() + ) + or + result = n.(ShorthandSelfParameterMention) + } + + class Expr = Rust::Expr; + + class Switch extends Rust::MatchExpr { + Expr getExpr() { result = this.getScrutinee() } + + Case getCase(int index) { result = this.getArm(index) } + } + + class Case extends Rust::MatchArm { + AstNode getAPattern() { result = this.getPat() } + + AstNode getBody() { result = this.getExpr() } + } + + class ConditionalExpr extends IfExpr { + Expr getThen() { result = super.getThen() } + } + + class BinaryExpr extends Rust::BinaryExpr { + Expr getLeftOperand() { result = super.getLhs() } + + Expr getRightOperand() { result = super.getRhs() } + } + + class LogicalAndExpr extends BinaryExpr, Rust::LogicalAndExpr { } + + class LogicalOrExpr extends BinaryExpr, Rust::LogicalOrExpr { } + + abstract class Assignment extends BinaryExpr { } + + class AssignExpr extends Assignment, Rust::AssignmentExpr { } + + class ParenExpr = Rust::ParenExpr; + + class Variable extends Rust::Variable { + AstNode getDefiningNode() { + result = this.getPat().getName() or + result = this.getParameter().(SelfParam) + } + + Expr getAnAccess() { result = super.getAnAccess() } + } + + abstract class LetDeclaration extends AstNode { + abstract predicate isCoercionSite(); + + abstract AstNode getLeftOperand(); + + abstract AstNode getRightOperand(); + } + + private class LetExprLetDeclaration extends LetDeclaration, LetExpr { + override predicate isCoercionSite() { not this.getPat() instanceof IdentPat } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getScrutinee() } + } + + private class LetStmtLetDeclaration extends LetDeclaration, LetStmt { + override predicate isCoercionSite() { + this.hasTypeRepr() or + not identLetStmt(this, _, _) + } + + override AstNode getLeftOperand() { result = this.getPat() } + + override AstNode getRightOperand() { result = this.getInitializer() } + } + + class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment; + + class TypePosition = FunctionPosition; + + class Callable extends FunctionCallMatchingInput::Declaration { + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) { + result = + tp.(TypeParamTypeParameter) + .getTypeParam() + .getAdditionalTypeBound(this.getFunction(), _) + .getTypeRepr() + } + } + + class Call extends FunctionCallMatchingInput::Access { + /** Gets the target of this call. */ + Callable getTargetCertain() { + exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p | + result.isFunction(i, f) and + p = CallExprImpl::getFunctionPath(this) and + f = resolvePath(p) and + f.isDirectlyFor(i) + ) + } + + Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) } + } + + bindingset[derefChainBorrow] + Type inferCallTypeBottomUp(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) { + result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path) + } + + Type inferCallReturnType(AstNode n, TypePath path) { + exists(Call call, TypePath path0 | + result = M3::inferCallReturnType(call, _, n, path0) and + if + // index expression `x[i]` desugars to `*x.index(i)`, so we must account for + // the implicit deref + call instanceof IndexExpr + then path0.isCons(getRefTypeParameter(_), path) + else path = path0 + ) + } + + Type inferCallArgumentTypeTopDown(AstNode n, TypePath path) { + exists(FunctionCallMatchingInput::Access call, FunctionPosition pos | + result = inferCallArgumentTypeTopDown(call, pos, n, _, _, path) and + not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) + ) + or + exists(FunctionCallMatchingInput::Access a | + result = inferFunctionCallSelfArgumentTypeTopDown(a, n, DerefChain::nil(), path) and + if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() + then not path.isEmpty() + else any() + ) + } + + predicate inferStepSymmetricCertain(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + n1 = + any(IdentPat ip | + n2 = ip.getName() and + prefix1.isEmpty() and + if ip.isRef() + then + exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + else prefix2.isEmpty() + ) + or + exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | + n1 = dce.getArgList() and + tt.getArity() = dce.getNumberOfSyntacticArguments() and + n2 = dce.getSyntacticPositionalArgument(i) and + prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and + prefix2.isEmpty() + ) + or + exists(ClosureExpr ce, int index | + n1 = ce and + n2 = ce.getParam(index).getPat() and + prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and + prefix2.isEmpty() + ) + } + + Type inferTypeCertainSpecific(AstNode n, TypePath path) { + result = inferFunctionBodyType(n, path) + or + result = inferLiteralType(n, path, true) + or + result = inferRefPatType(n) and + path.isEmpty() + or + result = inferRefExprType(n) and + path.isEmpty() + or + result = inferCertainStructExprType(n, path) + or + result = inferCertainStructPatType(n, path) + or + result = inferRangeExprType(n) and + path.isEmpty() + or + result = inferTupleRootType(n) and + path.isEmpty() + or + result = inferBlockExprType(n, path) + or + result = inferArrayExprType(n) and + path.isEmpty() + or + result = inferCastExprType(n, path) + or + exprHasUnitType(n) and + path.isEmpty() and + result instanceof UnitType + or + isPanicMacroCall(n) and + path.isEmpty() and + result instanceof NeverType + or + n instanceof ClosureExpr and + path.isEmpty() and + result = closureRootType() + } + + predicate inferStepSymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + prefix1.isEmpty() and + prefix2.isEmpty() and + ( + n1 = n2.(OrPat).getAPat() + or + n1 = n2.(ParenPat).getPat() + or + n1 = n2.(LiteralPat).getLiteral() + or + exists(BreakExpr break | + break.getExpr() = n1 and + break.getTarget() = n2.(LoopExpr) + ) + or + n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and + not isPanicMacroCall(n2) + or + n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() + ) + or + n2 = + any(RefExpr re | + n1 = re.getExpr() and + prefix1.isEmpty() and + prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) + ) + or + n2 = + any(RefPat rp | + n1 = rp.getPat() and + prefix1.isEmpty() and + exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) + ) + or + exists(int i, int arity | + prefix1.isEmpty() and + prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) + | + arity = n2.(TupleExpr).getNumberOfFields() and + n1 = n2.(TupleExpr).getField(i) + or + arity = n2.(TuplePat).getTupleArity() and + n1 = n2.(TuplePat).getField(i) + ) + or + exists(BlockExpr be | + n1 = be and + n2 = be.getStmtList().getTailExpr() and + if be.isAsync() + then + prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and + prefix2.isEmpty() + else ( + prefix1.isEmpty() and + prefix2.isEmpty() + ) + ) + or + // an array repeat expression (`[1; 3]`) has the type of the repeat operand + n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and + prefix1 = TypePath::singleton(getArrayTypeParameter()) and + prefix2.isEmpty() + } + + predicate inferStep(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { + // When `n2` is `*n1` propagate type information from a raw pointer type + // parameter at `n1`. The other direction is handled in + // `inferDereferencedExprPtrType`. + n1 = n2.(DerefExpr).getExpr() and + prefix1 = TypePath::singleton(getPtrTypeParameter()) and + prefix2.isEmpty() + or + n2 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n1) and + prefix2 = closureReturnPath() and + prefix1.isEmpty() + } + + predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + ( + n1 = n2.(ArrayListExpr).getAnExpr() and + path2 = TypePath::singleton(getArrayTypeParameter()) + or + exists(ReturnExpr re, Rust::Callable c | + n1 = re.getExpr() and + c = re.getEnclosingCallable() and + n2 = c.getBody() and + path2.isEmpty() + ) + or + exists(Struct s | + n1 = [n2.(RangeExpr).getStart(), n2.(RangeExpr).getEnd()] and + path2 = + TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and + s = getRangeType(n2) + ) + ) + } + + Type inferTypeTopDown(AstNode n, TypePath path) { + result = inferTypeFromAnnotationTopDown(n, path) + or + result = inferClosureExprBodyTypeTopDown(n, path) + or + exists(FunctionPosition pos | not pos.isReturn() | + result = inferConstructionType(n, pos, path) + or + result = inferOperationType(n, pos, path) + ) + } + + Type inferTypeSpecific(AstNode n, TypePath path) { + result = inferAssignmentOperationType(n, path) + or + exists(FunctionPosition pos | pos.isReturn() | + result = inferConstructionType(n, pos, path) + or + result = inferOperationType(n, pos, path) + ) + or + result = inferFieldExprType(n, path) + or + result = inferTryExprType(n, path) + or + result = inferLiteralType(n, path, false) + or + result = inferAwaitExprType(n, path) + or + result = inferDereferencedExprPtrType(n, path) + or + result = inferForLoopExprType(n, path) + or + result = inferClosureExprType(n, path) + or + result = inferArgList(n, path) + or + result = inferDeconstructionPatType(n, path) + or + result = inferUnknownType(n, path) + } +} + +private module M3 = Make3; + +// import M3 +predicate inferType = M3::inferType/1; + +predicate inferType = M3::inferType/2; + +predicate inferTypeCertain = M3::inferTypeCertain/2; + module Consistency { import M2::Consistency - private Type inferCertainTypeAdj(AstNode n, TypePath path) { - result = CertainTypeInference::inferCertainType(n, path) and + private Type inferTypeCertainAdj(AstNode n, TypePath path) { + result = inferTypeCertain(n, path) and not result = TNeverType() } predicate nonUniqueCertainType(AstNode n, TypePath path, Type t) { - strictcount(inferCertainTypeAdj(n, path)) > 1 and - t = inferCertainTypeAdj(n, path) and + strictcount(inferTypeCertainAdj(n, path)) > 1 and + t = inferTypeCertainAdj(n, path) and // Suppress the inconsistency if `n` is a self parameter and the type // mention for the self type has multiple types for a path. not exists(ImplItemNode impl, TypePath selfTypePath | @@ -408,16 +797,38 @@ private class AssocFunctionDeclaration extends FunctionDeclaration { } pragma[nomagic] -private TypeMention getCallExprTypeMentionArgument(CallExpr ce, TypeArgumentPosition apos) { - exists(Path p, int i | p = CallExprImpl::getFunctionPath(ce) | - apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i)) and - result = getPathTypeArgument(p, pragma[only_bind_into](i)) +private TypePath getPathToImplSelfTypeParam(TypeParam tp) { + exists(ImplItemNode impl | + tp = impl.getTypeParam(_) and + TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) ) } pragma[nomagic] private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, TypePath path) { - result = getCallExprTypeMentionArgument(ce, apos).getTypeAt(path) + exists(Path p, ItemNode resolved, TypeParam tp | + p = CallExprImpl::getFunctionPath(ce) and + resolved = resolvePath(p) and + apos.asTypeParam() = tp + | + // For type parameters of the function we must resolve their + // instantiation from the path. For instance, for `fn bar(a: A) -> A` + // and the path `bar`, we must resolve `A` to `i64`. + exists(int i | + tp = resolved.getTypeParam(pragma[only_bind_into](i)) and + result = getPathTypeArgument(p, pragma[only_bind_into](i)).getTypeAt(path) + ) + or + // For type parameters of the `impl` block we must resolve their + // instantiation from the path. For instance, for `impl for Foo` + // and the path `Foo::bar` we must resolve `A` to `i64`. + exists(ImplItemNode impl, TypePath pathToTp | + resolved = impl.getASuccessor(_) and + tp = impl.getTypeParam(_) and + pathToTp = getPathToImplSelfTypeParam(tp) and + result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(path)) + ) + ) or // Handle constructions that use `Self(...)` syntax exists(Path p, TypePath path0 | @@ -427,29 +838,6 @@ private Type getCallExprTypeArgument(CallExpr ce, TypeArgumentPosition apos, Typ ) } -/** Gets the type annotation that applies to `n`, if any. */ -private TypeMention getTypeAnnotation(AstNode n) { - exists(LetStmt let | - n = let.getPat() and - result = let.getTypeRepr() - ) - or - result = n.(SelfParam).getTypeRepr() - or - exists(Param p | - n = p.getPat() and - result = p.getTypeRepr() - ) -} - -/** Gets the type of `n`, which has an explicit type annotation. */ -pragma[nomagic] -private Type inferAnnotatedType(AstNode n, TypePath path) { - result = getTypeAnnotation(n).getTypeAt(path) - or - result = n.(ShorthandSelfParameterMention).getTypeAt(path) -} - pragma[nomagic] private Type inferFunctionBodyType(AstNode n, TypePath path) { exists(Function f | @@ -507,242 +895,12 @@ private TypePath closureParameterPath(int arity, int index) { TypePath::singleton(getTupleTypeParameter(arity, index))) } -/** Module for inferring certain type information. */ -module CertainTypeInference { - pragma[nomagic] - private predicate callResolvesTo(CallExpr ce, Path p, Function f) { - p = CallExprImpl::getFunctionPath(ce) and - f = resolvePath(p) - } - - pragma[nomagic] - private Type getCallExprType(CallExpr ce, Path p, FunctionDeclaration f, TypePath path) { - exists(ImplOrTraitItemNodeOption i | - callResolvesTo(ce, p, f) and - result = f.getReturnType(i, path) and - f.isDirectlyFor(i) - ) - } - - pragma[nomagic] - private Type getCertainCallExprType(CallExpr ce, Path p, TypePath tp) { - forex(Function f | callResolvesTo(ce, p, f) | result = getCallExprType(ce, p, f, tp)) - } - - pragma[nomagic] - private TypePath getPathToImplSelfTypeParam(TypeParam tp) { - exists(ImplItemNode impl | - tp = impl.getTypeParam(_) and - TTypeParamTypeParameter(tp) = impl.(Impl).getSelfTy().(TypeMention).getTypeAt(result) - ) - } - - pragma[nomagic] - private Type inferCertainCallExprType(CallExpr ce, TypePath path) { - exists(Type ty, TypePath prefix, Path p | ty = getCertainCallExprType(ce, p, prefix) | - exists(TypePath suffix, TypeParam tp | - tp = ty.(TypeParamTypeParameter).getTypeParam() and - path = prefix.append(suffix) - | - // For type parameters of the `impl` block we must resolve their - // instantiation from the path. For instance, for `impl for Foo` - // and the path `Foo::bar` we must resolve `A` to `i64`. - exists(TypePath pathToTp | - pathToTp = getPathToImplSelfTypeParam(tp) and - result = p.getQualifier().(TypeMention).getTypeAt(pathToTp.appendInverse(suffix)) - ) - or - // For type parameters of the function we must resolve their - // instantiation from the path. For instance, for `fn bar(a: A) -> A` - // and the path `bar`, we must resolve `A` to `i64`. - result = getCallExprTypeArgument(ce, TTypeParamTypeArgumentPosition(tp), suffix) - ) - or - not ty instanceof TypeParameter and - result = ty and - path = prefix - ) - } - - private Type inferCertainStructExprType(StructExpr se, TypePath path) { - result = se.getPath().(TypeMention).getTypeAt(path) - } - - private Type inferCertainStructPatType(StructPat sp, TypePath path) { - result = sp.getPath().(TypeMention).getTypeAt(path) - } - - predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - exists(Variable v | n1 = v.getAnAccess() | - n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam) - ) - or - // A `let` statement with a type annotation is a coercion site and hence - // is not a certain type equality. - exists(LetStmt let | - not let.hasTypeRepr() and - identLetStmt(let, n1, n2) - ) - or - exists(LetExpr let | - // Similarly as for let statements, we need to rule out binding modes - // changing the type. - let.getPat().(IdentPat) = n1 and - let.getScrutinee() = n2 - ) - or - n1 = n2.(ParenExpr).getExpr() - ) - or - n1 = - any(IdentPat ip | - n2 = ip.getName() and - prefix1.isEmpty() and - if ip.isRef() - then - exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - else prefix2.isEmpty() - ) - or - exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i | - n1 = dce.getArgList() and - tt.getArity() = dce.getNumberOfSyntacticArguments() and - n2 = dce.getSyntacticPositionalArgument(i) and - prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and - prefix2.isEmpty() - ) - or - exists(ClosureExpr ce, int index | - n1 = ce and - n2 = ce.getParam(index).getPat() and - prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and - prefix2.isEmpty() - ) - } - - pragma[nomagic] - private Type inferCertainTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferCertainType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - certainTypeEquality(n, prefix1, n2, prefix2) - or - certainTypeEquality(n2, prefix2, n, prefix1) - ) - } - - /** - * Holds if `n` has complete and certain type information and if `n` has the - * resulting type at `path`. - */ - cached - Type inferCertainType(AstNode n, TypePath path) { - result = inferAnnotatedType(n, path) and - Stages::TypeInferenceStage::ref() - or - result = inferFunctionBodyType(n, path) - or - result = inferCertainCallExprType(n, path) - or - result = inferCertainTypeEquality(n, path) - or - result = inferLiteralType(n, path, true) - or - result = inferRefPatType(n) and - path.isEmpty() - or - result = inferRefExprType(n) and - path.isEmpty() - or - result = inferLogicalOperationType(n, path) - or - result = inferCertainStructExprType(n, path) - or - result = inferCertainStructPatType(n, path) - or - result = inferRangeExprType(n) and - path.isEmpty() - or - result = inferTupleRootType(n) and - path.isEmpty() - or - result = inferBlockExprType(n, path) - or - result = inferArrayExprType(n) and - path.isEmpty() - or - result = inferCastExprType(n, path) - or - exprHasUnitType(n) and - path.isEmpty() and - result instanceof UnitType - or - isPanicMacroCall(n) and - path.isEmpty() and - result instanceof NeverType - or - n instanceof ClosureExpr and - path.isEmpty() and - result = closureRootType() - or - infersCertainTypeAt(n, path, result.getATypeParameter()) - } - - /** - * Holds if `n` has complete and certain type information at the type path - * `prefix.tp`. This entails that the type at `prefix` must be the type - * that declares `tp`. - */ - pragma[nomagic] - private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { - exists(TypePath path | - exists(inferCertainType(n, path)) and - path.isSnoc(prefix, tp) - ) - } - - /** - * Holds if `n` has complete and certain type information at `path`. - */ - pragma[nomagic] - predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } - - /** - * Holds if `n` having type `t` at `path` conflicts with certain type information - * at `prefix`. - */ - bindingset[n, prefix, path, t] - pragma[inline_late] - predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { - inferCertainType(n, path) != t - or - // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also - // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, - // then it must be the case that `T(i+1)` is a type parameter of `certainType`, - // otherwise there is a conflict. - // - // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. - exists(TypePath suffix, TypeParameter tp, Type certainType | - path = prefix.appendInverse(suffix) and - tp = suffix.getHead() and - inferCertainType(n, prefix) = certainType and - not certainType.getATypeParameter() = tp - ) - } +private Type inferCertainStructExprType(StructExpr se, TypePath path) { + result = se.getPath().(TypeMention).getTypeAt(path) } -private Type inferLogicalOperationType(AstNode n, TypePath path) { - exists(Builtins::Bool t, BinaryLogicalOperation be | - n = [be, be.getLhs(), be.getRhs()] and - path.isEmpty() and - result = TDataType(t) - ) +private Type inferCertainStructPatType(StructPat sp, TypePath path) { + result = sp.getPath().(TypeMention).getTypeAt(path) } private Type inferAssignmentOperationType(AstNode n, TypePath path) { @@ -772,171 +930,8 @@ private Struct getRangeType(RangeExpr re) { result instanceof RangeToInclusiveStruct } -private predicate bodyReturns(Expr body, Expr e) { - exists(ReturnExpr re, Callable c | - e = re.getExpr() and - c = re.getEnclosingCallable() and - body = c.getBody() - ) -} - -/** - * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree - * of `n2` at `prefix2` and type information should propagate in both directions - * through the type equality. - */ -private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2) - or - prefix1.isEmpty() and - prefix2.isEmpty() and - ( - exists(LetStmt let | - let.getPat() = n1 and - let.getInitializer() = n2 - ) - or - n2 = - any(MatchExpr me | - n1 = me.getAnArm().getExpr() and - me.getNumberOfArms() = 1 - ) - or - exists(LetExpr let | - n1 = let.getScrutinee() and - n2 = let.getPat() - ) - or - exists(MatchExpr me | - n1 = me.getScrutinee() and - n2 = me.getAnArm().getPat() - ) - or - n1 = n2.(OrPat).getAPat() - or - n1 = n2.(ParenPat).getPat() - or - n1 = n2.(LiteralPat).getLiteral() - or - exists(BreakExpr break | - break.getExpr() = n1 and - break.getTarget() = n2.(LoopExpr) - ) - or - exists(AssignmentExpr be | - n1 = be.getLhs() and - n2 = be.getRhs() - ) - or - n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and - not isPanicMacroCall(n2) - or - n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion() - or - bodyReturns(n1, n2) and - strictcount(Expr e | bodyReturns(n1, e)) = 1 - ) - or - n2 = - any(RefExpr re | - n1 = re.getExpr() and - prefix1.isEmpty() and - prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0)) - ) - or - n2 = - any(RefPat rp | - n1 = rp.getPat() and - prefix1.isEmpty() and - exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | - prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) - ) - ) - or - exists(int i, int arity | - prefix1.isEmpty() and - prefix2 = TypePath::singleton(getTupleTypeParameter(arity, i)) - | - arity = n2.(TupleExpr).getNumberOfFields() and - n1 = n2.(TupleExpr).getField(i) - or - arity = n2.(TuplePat).getTupleArity() and - n1 = n2.(TuplePat).getField(i) - ) - or - exists(BlockExpr be | - n1 = be and - n2 = be.getStmtList().getTailExpr() and - if be.isAsync() - then - prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and - prefix2.isEmpty() - else ( - prefix1.isEmpty() and - prefix2.isEmpty() - ) - ) - or - // an array list expression with only one element (such as `[1]`) has type from that element - n1 = - any(ArrayListExpr ale | - ale.getAnExpr() = n2 and - ale.getNumberOfExprs() = 1 - ) and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() - or - // an array repeat expression (`[1; 3]`) has the type of the repeat operand - n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and - prefix1 = TypePath::singleton(getArrayTypeParameter()) and - prefix2.isEmpty() -} - -/** - * Holds if `child` is a child of `parent`, and the Rust compiler applies [least - * upper bound (LUB) coercion][1] to infer the type of `parent` from the type of - * `child`. - * - * In this case, we want type information to only flow from `child` to `parent`, - * to avoid (a) either having to model LUB coercions, or (b) risk combinatorial - * explosion in inferred types. - * - * [1]: https://doc.rust-lang.org/reference/type-coercions.html#r-coerce.least-upper-bound - */ -private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) { - child = parent.(IfExpr).getABranch() and - prefix.isEmpty() - or - parent = - any(MatchExpr me | - child = me.getAnArm().getExpr() and - me.getNumberOfArms() > 1 - ) and - prefix.isEmpty() - or - parent = - any(ArrayListExpr ale | - child = ale.getAnExpr() and - ale.getNumberOfExprs() > 1 - ) and - prefix = TypePath::singleton(getArrayTypeParameter()) - or - bodyReturns(parent, child) and - strictcount(Expr e | bodyReturns(parent, e)) > 1 and - prefix.isEmpty() - or - parent = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = child) and - prefix = closureReturnPath() - or - exists(Struct s | - child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and - prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and - s = getRangeType(parent) - ) -} - -private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { - inferType(n, path) = TUnknownType() and +pragma[nomagic] +private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) { // Normally, these are coercion sites, but in case a type is unknown we // allow for type information to flow from the type annotation. exists(TypeMention tm | result = tm.getTypeAt(path) | @@ -948,46 +943,6 @@ private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) { ) } -/** - * Holds if the type tree of `n1` at `prefix1` should be equal to the type tree - * of `n2` at `prefix2`, but type information should only propagate from `n1` to - * `n2`. - */ -private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) { - lubCoercion(n2, n1, prefix2) and - prefix1.isEmpty() - or - exists(AstNode mid, TypePath prefixMid, TypePath suffix | - typeEquality(n1, prefixMid, mid, prefix2) or - typeEquality(mid, prefix2, n1, prefixMid) - | - lubCoercion(mid, n2, suffix) and - not lubCoercion(mid, n1, _) and - prefix1 = prefixMid.append(suffix) - ) - or - // When `n2` is `*n1` propagate type information from a raw pointer type - // parameter at `n1`. The other direction is handled in - // `inferDereferencedExprPtrType`. - n1 = n2.(DerefExpr).getExpr() and - prefix1 = TypePath::singleton(getPtrTypeParameter()) and - prefix2.isEmpty() -} - -pragma[nomagic] -private Type inferTypeEquality(AstNode n, TypePath path) { - exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | - result = inferType(n2, prefix2.appendInverse(suffix)) and - path = prefix1.append(suffix) - | - typeEquality(n, prefix1, n2, prefix2) - or - typeEquality(n2, prefix2, n, prefix1) - or - typeEqualityAsymmetric(n2, prefix2, n, prefix1) - ) -} - pragma[nomagic] private TupleType inferTupleRootType(AstNode n) { // `typeEquality` handles the non-root cases @@ -1132,7 +1087,7 @@ private module ContextTyping { * context in which the call appears, for example a call like * `Default::default()`. */ - abstract class ContextTypedCallCand extends AstNode { + abstract class ContextTypedCallCand extends Expr { abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path); predicate hasTypeArgument(TypeArgumentPosition apos) { exists(this.getTypeArgument(apos, _)) } @@ -1163,53 +1118,6 @@ private module ContextTyping { ) } } - - pragma[nomagic] - private predicate hasUnknownTypeAt(AstNode n, TypePath path) { - inferType(n, path) = TUnknownType() - } - - pragma[nomagic] - private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } - - newtype FunctionPositionKind = - SelfKind() or - ReturnKind() or - PositionalKind() - - signature Type inferCallTypeSig(AstNode n, FunctionPositionKind kind, TypePath path); - - /** - * Given a predicate `inferCallType` for inferring the type of a call at a given - * position, this module exposes the predicate `check`, which wraps the input - * predicate and checks that types are only propagated into arguments when they - * are context-typed. - */ - module CheckContextTyping { - pragma[nomagic] - private Type inferCallNonReturnType( - AstNode n, FunctionPositionKind kind, TypePath prefix, TypePath path - ) { - result = inferCallType(n, kind, path) and - hasUnknownType(n) and - kind != ReturnKind() and - prefix = path.getAPrefix() - } - - pragma[nomagic] - Type check(AstNode n, TypePath path) { - result = inferCallType(n, ReturnKind(), path) - or - exists(FunctionPositionKind kind, TypePath prefix | - result = inferCallNonReturnType(n, kind, prefix, path) and - hasUnknownTypeAt(n, prefix) - | - // Never propagate type information directly into the receiver, since its type - // must already have been known in order to resolve the call - if kind = SelfKind() then not prefix.isEmpty() else any() - ) - } - } } /** @@ -2677,6 +2585,11 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput FunctionDeclaration getFunction() { result = f } + predicate isFunction(ImplOrTraitItemNodeOption i_, Function f_) { + i_ = i and + f_ = f + } + predicate isAssocFunction(ImplOrTraitItemNode i_, Function f_) { i_ = i.asSome() and f_ = f @@ -2745,7 +2658,9 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput ) } - abstract class Access extends ContextTyping::ContextTypedCallCand { + final class Access = AccessImpl; + + abstract private class AccessImpl extends ContextTyping::ContextTypedCallCand { abstract AstNode getNodeAt(FunctionPosition pos); bindingset[derefChainBorrow] @@ -2760,7 +2675,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput abstract predicate hasUnknownTypeAt(string derefChainBorrow, FunctionPosition pos, TypePath path); } - private class AssocFunctionCallAccess extends Access instanceof AssocFunctionResolution::AssocFunctionCall + private class AssocFunctionCallAccess extends AccessImpl instanceof AssocFunctionResolution::AssocFunctionCall { AssocFunctionCallAccess() { // handled in the `OperationMatchingInput` module @@ -2847,7 +2762,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } } - private class NonAssocFunctionCallAccess extends Access instanceof NonAssocCallExpr, + private class NonAssocFunctionCallAccess extends AccessImpl instanceof NonAssocCallExpr, CallExprImpl::CallExprCall { pragma[nomagic] @@ -2900,39 +2815,14 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput } } -private module FunctionCallMatching = MatchingWithEnvironment; - pragma[nomagic] -private Type inferFunctionCallType0( +private Type inferCallArgumentTypeTopDown( FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain, BorrowKind borrow, TypePath path ) { - exists(TypePath path0 | - n = call.getNodeAt(pos) and - exists(string derefChainBorrow | - FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) - | - result = FunctionCallMatching::inferAccessType(call, derefChainBorrow, pos, path0) - or - call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and - result = TUnknownType() - ) - | - if - // index expression `x[i]` desugars to `*x.index(i)`, so we must account for - // the implicit deref - pos.isReturn() and - call instanceof IndexExpr - then path0.isCons(getRefTypeParameter(_), path) - else path = path0 - ) -} - -pragma[nomagic] -private Type inferFunctionCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePath path) { - exists(FunctionCallMatchingInput::Access call | - result = inferFunctionCallType0(call, pos, n, _, _, path) and - not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) + exists(string derefChainBorrow | + FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) and + result = M3::inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path) ) } @@ -2944,12 +2834,12 @@ private Type inferFunctionCallTypeNonSelf(AstNode n, FunctionPosition pos, TypeP * empty, at which point the inferred type can be applied back to `n`. */ pragma[nomagic] -private Type inferFunctionCallTypeSelf( +private Type inferFunctionCallSelfArgumentTypeTopDown( FunctionCallMatchingInput::Access call, AstNode n, DerefChain derefChain, TypePath path ) { exists(FunctionPosition pos, BorrowKind borrow, TypePath path0 | call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) and - result = inferFunctionCallType0(call, pos, n, derefChain, borrow, path0) + result = inferCallArgumentTypeTopDown(call, pos, n, derefChain, borrow, path0) | borrow.isNoBorrow() and path = path0 @@ -2966,7 +2856,7 @@ private Type inferFunctionCallTypeSelf( DerefChain derefChain0, Type t0, TypePath path0, DerefImplItemNode impl, Type selfParamType, TypePath selfPath | - t0 = inferFunctionCallTypeSelf(call, n, derefChain0, path0) and + t0 = inferFunctionCallSelfArgumentTypeTopDown(call, n, derefChain0, path0) and derefChain0.isCons(impl, derefChain) and selfParamType = impl.resolveSelfTypeAt(selfPath) | @@ -2983,31 +2873,6 @@ private Type inferFunctionCallTypeSelf( ) } -private Type inferFunctionCallTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(FunctionPosition pos | - result = inferFunctionCallTypeNonSelf(n, pos, path) and - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - ) - or - exists(FunctionCallMatchingInput::Access a | - result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and - if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() - then kind = ContextTyping::SelfKind() - else kind = ContextTyping::PositionalKind() - ) -} - -/** - * Gets the type of `n` at `path`, where `n` is either a function call or an - * argument/receiver of a function call. - */ -private predicate inferFunctionCallType = - ContextTyping::CheckContextTyping::check/2; - abstract private class Constructor extends Addressable { final TypeParameter getTypeParameter(TypeParameterPosition ppos) { typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos) @@ -3141,7 +3006,7 @@ private module ConstructionMatchingInput implements MatchingInputSig { or exists(TypePath suffix | suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and - result = CertainTypeInference::inferCertainType(this, suffix) + result = inferTypeCertain(this, suffix) ) } @@ -3166,24 +3031,34 @@ private module ConstructionMatchingInput implements MatchingInputSig { private module ConstructionMatching = Matching; pragma[nomagic] -private Type inferConstructionTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(ConstructionMatchingInput::Access a, FunctionPosition pos | +private Type inferConstructionType(AstNode n, FunctionPosition pos, TypePath path) { + exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) and - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() - | result = ConstructionMatching::inferAccessType(a, pos, path) - or - a.hasUnknownTypeAt(pos, path) and - result = TUnknownType() ) } -private predicate inferConstructionType = - ContextTyping::CheckContextTyping::check/2; +pragma[nomagic] +private Type inferUnknownType(AstNode n, TypePath path) { + result = TUnknownType() and + ( + exists(FunctionCallMatchingInput::Access call, FunctionPosition pos | + n = call.getNodeAt(pos) and + call.hasUnknownTypeAt(_, pos, path) + ) + or + exists(ConstructionMatchingInput::Access a, FunctionPosition pos | + n = a.getNodeAt(pos) and + a.hasUnknownTypeAt(pos, path) + ) + or + exists(Param p | + not p.hasTypeRepr() and + n = p.getPat() and + path.isEmpty() + ) + ) +} /** * A matching configuration for resolving types of operations like `a + b`. @@ -3248,24 +3123,14 @@ private module OperationMatchingInput implements MatchingInputSig { private module OperationMatching = Matching; pragma[nomagic] -private Type inferOperationTypePreCheck( - AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path -) { - exists(OperationMatchingInput::Access a, FunctionPosition pos | +private Type inferOperationType(AstNode n, FunctionPosition pos, TypePath path) { + exists(OperationMatchingInput::Access a | n = a.getNodeAt(pos) and result = OperationMatching::inferAccessType(a, pos, path) and - if pos.asPosition() = 0 - then kind = ContextTyping::SelfKind() - else - if pos.isPosition() - then kind = ContextTyping::PositionalKind() - else kind = ContextTyping::ReturnKind() + if pos.asPosition() = 0 then not path.isEmpty() else any() ) } -private predicate inferOperationType = - ContextTyping::CheckContextTyping::check/2; - pragma[nomagic] private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) { exists(TypePath path | @@ -3285,6 +3150,20 @@ private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC ) } +/** + * Gets the struct field that the field expression `fe` resolves to, if any. + */ +cached +StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { + M3::CachedStage::ref() and + exists(string name, DataType ty | + ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) + | + result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or + result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) + ) +} + pragma[nomagic] private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain derefChain) { exists(string name | @@ -3293,6 +3172,21 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref ) } +/** + * Gets the tuple field that the field expression `fe` resolves to, if any. + */ +cached +TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { + M3::CachedStage::ref() and + exists(int i | + result = + getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) + .(StructType) + .getTypeItem() + .getTupleField(pragma[only_bind_into](i)) + ) +} + /** * A matching configuration for resolving types of field expressions like `x.field`. */ @@ -3739,26 +3633,21 @@ private Type inferForLoopExprType(AstNode n, TypePath path) { } pragma[nomagic] -private Type inferClosureExprType(AstNode n, TypePath path) { +private Type inferClosureExprBodyTypeTopDown(AstNode n, TypePath path) { exists(ClosureExpr ce | - n = ce and - ( - path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and - result.(TupleType).getArity() = ce.getNumberOfParams() - or - exists(TypePath path0 | - result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and - path = closureReturnPath().append(path0) - ) - ) - or - exists(Param p | - p = ce.getAParam() and - not p.hasTypeRepr() and - n = p.getPat() and - result = TUnknownType() and - path.isEmpty() - ) + n = ce.getClosureBody() and + result = inferType(ce, closureReturnPath().appendInverse(path)) + ) +} + +pragma[nomagic] +private Type inferClosureExprType(ClosureExpr ce, TypePath path) { + path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and + result.(TupleType).getArity() = ce.getNumberOfParams() + or + exists(TypePath suffix | + result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(suffix) and + path = closureReturnPath().append(suffix) ) } @@ -3776,170 +3665,50 @@ private Type inferCastExprType(CastExpr ce, TypePath path) { result = ce.getTypeRepr().(TypeMention).getTypeAt(path) } +/** Holds if `n` is implicitly dereferenced and/or borrowed. */ cached -private module Cached { - /** Holds if `n` is implicitly dereferenced and/or borrowed. */ - cached - predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { - exists(BorrowKind bk | - any(AssocFunctionResolution::AssocFunctionCall afc) - .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and - if bk.isNoBorrow() then borrow = false else borrow = true - ) - or - e = - any(FieldExpr fe | - exists(resolveStructFieldExpr(fe, derefChain)) - or - exists(resolveTupleFieldExpr(fe, derefChain)) - ).getContainer() and - not derefChain.isEmpty() and - borrow = false - } - - /** - * Gets an item (function or tuple struct/variant) that `call` resolves to, if - * any. - * - * The parameter `dispatch` is `true` if and only if the resolved target is a - * trait item because a precise target could not be determined from the - * types (for instance in the presence of generics or `dyn` types) - */ - cached - Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { - dispatch = false and - result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() - or - exists(ImplOrTraitItemNode i | - i instanceof TraitItemNode and dispatch = true - or - i instanceof ImplItemNode and dispatch = false - | - result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and - not call instanceof CallExprImpl::DynamicCallExpr and - not i instanceof Builtins::BuiltinImpl - ) - } - - /** - * Gets the struct field that the field expression `fe` resolves to, if any. - */ - cached - StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(string name, DataType ty | - ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain) - | - result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or - result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) - ) - } - - /** - * Gets the tuple field that the field expression `fe` resolves to, if any. - */ - cached - TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) { - exists(int i | - result = - getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain) - .(StructType) - .getTypeItem() - .getTupleField(pragma[only_bind_into](i)) - ) - } - - /** - * Gets a type at `path` that `n` infers to, if any. - * - * The type inference implementation works by computing all possible types, so - * the result is not necessarily unique. For example, in - * - * ```rust - * trait MyTrait { - * fn foo(&self) -> &Self; - * - * fn bar(&self) -> &Self { - * self.foo() - * } - * } - * - * struct MyStruct; - * - * impl MyTrait for MyStruct { - * fn foo(&self) -> &MyStruct { - * self - * } - * } - * - * fn baz() { - * let x = MyStruct; - * x.bar(); - * } - * ``` - * - * the type inference engine will roughly make the following deductions: - * - * 1. `MyStruct` has type `MyStruct`. - * 2. `x` has type `MyStruct` (via 1.). - * 3. The return type of `bar` is `&Self`. - * 3. `x.bar()` has type `&MyStruct` (via 2 and 3, by matching the implicit `Self` - * type parameter with `MyStruct`.). - * 4. The return type of `bar` is `&MyTrait`. - * 5. `x.bar()` has type `&MyTrait` (via 2 and 4). - */ - cached - Type inferType(AstNode n, TypePath path) { - Stages::TypeInferenceStage::ref() and - result = CertainTypeInference::inferCertainType(n, path) - or - // Don't propagate type information into a node which conflicts with certain - // type information. - forall(TypePath prefix | - CertainTypeInference::hasInferredCertainType(n, prefix) and - prefix.isPrefixOf(path) - | - not CertainTypeInference::certainTypeConflict(n, prefix, path, result) - ) and - ( - result = inferAssignmentOperationType(n, path) - or - result = inferTypeEquality(n, path) - or - result = inferFunctionCallType(n, path) - or - result = inferConstructionType(n, path) - or - result = inferOperationType(n, path) - or - result = inferFieldExprType(n, path) - or - result = inferTryExprType(n, path) - or - result = inferLiteralType(n, path, false) - or - result = inferAwaitExprType(n, path) - or - result = inferDereferencedExprPtrType(n, path) - or - result = inferForLoopExprType(n, path) - or - result = inferClosureExprType(n, path) - or - result = inferArgList(n, path) - or - result = inferDeconstructionPatType(n, path) +predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) { + M3::CachedStage::ref() and + exists(BorrowKind bk | + any(AssocFunctionResolution::AssocFunctionCall afc) + .argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and + if bk.isNoBorrow() then borrow = false else borrow = true + ) + or + e = + any(FieldExpr fe | + exists(resolveStructFieldExpr(fe, derefChain)) or - result = inferUnknownTypeFromAnnotation(n, path) - ) - } + exists(resolveTupleFieldExpr(fe, derefChain)) + ).getContainer() and + not derefChain.isEmpty() and + borrow = false } -import Cached - /** - * Gets a type that `n` infers to, if any. + * Gets an item (function or tuple struct/variant) that `call` resolves to, if + * any. + * + * The parameter `dispatch` is `true` if and only if the resolved target is a + * trait item because a precise target could not be determined from the + * types (for instance in the presence of generics or `dyn` types) */ -Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } +cached +Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) { + M3::CachedStage::ref() and + dispatch = false and + result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution() + or + exists(ImplOrTraitItemNode i | + i instanceof TraitItemNode and dispatch = true + or + i instanceof ImplItemNode and dispatch = false + | + result = call.(AssocFunctionResolution::AssocFunctionCall).resolveCallTarget(i, _, _, _) and + not call instanceof CallExprImpl::DynamicCallExpr and + not i instanceof Builtins::BuiltinImpl + ) +} /** Provides predicates for debugging the type inference implementation. */ private module Debug { @@ -3973,26 +3742,11 @@ private module Debug { t = self.getTypeAt(path) } - predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) { - n = getRelevantLocatable() and - t = inferFunctionCallType(n, path) - } - - predicate debugInferConstructionType(AstNode n, TypePath path, Type t) { - n = getRelevantLocatable() and - t = inferConstructionType(n, path) - } - predicate debugTypeMention(TypeMention tm, TypePath path, Type type) { tm = getRelevantLocatable() and tm.getTypeAt(path) = type } - Type debugInferAnnotatedType(AstNode n, TypePath path) { - n = getRelevantLocatable() and - result = inferAnnotatedType(n, path) - } - pragma[nomagic] private int countTypesAtPath(AstNode n, TypePath path, Type t) { t = inferType(n, path) and @@ -4041,9 +3795,9 @@ private module Debug { c = max(countTypePaths(_, _, _)) } - Type debugInferCertainType(AstNode n, TypePath path) { + Type debugInferTypeCertain(AstNode n, TypePath path) { n = getRelevantLocatable() and - result = CertainTypeInference::inferCertainType(n, path) + result = inferTypeCertain(n, path) } Type debugInferCertainNonUniqueType(AstNode n, TypePath path) { diff --git a/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected b/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected new file mode 100644 index 000000000000..cfad81d2796a --- /dev/null +++ b/rust/ql/test/library-tests/dataflow/models/CONSISTENCY/PathResolutionConsistency.expected @@ -0,0 +1,2 @@ +multipleResolvedTargets +| main.rs:218:20:218:25 | ... != ... | diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 3344fc45f74f..94b98d92f7da 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -8961,10 +8961,8 @@ inferType | main.rs:826:16:826:16 | 3 | | {EXTERNAL LOCATION} | i32 | | main.rs:826:16:826:20 | ... > ... | | {EXTERNAL LOCATION} | bool | | main.rs:826:20:826:20 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:826:22:828:13 | { ... } | | main.rs:820:20:820:22 | Tr2 | | main.rs:827:17:827:20 | self | | {EXTERNAL LOCATION} | & | | main.rs:827:17:827:20 | self | TRef | main.rs:820:5:832:5 | Self [trait MyTrait2] | -| main.rs:827:17:827:25 | self.m1() | | main.rs:820:20:820:22 | Tr2 | | main.rs:828:20:830:13 | { ... } | | main.rs:820:20:820:22 | Tr2 | | main.rs:829:17:829:31 | ...::m1(...) | | main.rs:820:20:820:22 | Tr2 | | main.rs:829:26:829:30 | * ... | | main.rs:820:5:832:5 | Self [trait MyTrait2] | @@ -11481,13 +11479,9 @@ inferType | main.rs:2099:13:2103:13 | if value {...} else {...} | | {EXTERNAL LOCATION} | i64 | | main.rs:2099:16:2099:20 | value | | {EXTERNAL LOCATION} | bool | | main.rs:2099:22:2101:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2099:22:2101:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2100:17:2100:17 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2100:17:2100:17 | 1 | | {EXTERNAL LOCATION} | i64 | | main.rs:2101:20:2103:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2101:20:2103:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2102:17:2102:17 | 0 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2102:17:2102:17 | 0 | | {EXTERNAL LOCATION} | i64 | | main.rs:2113:19:2113:22 | SelfParam | | main.rs:2107:5:2107:19 | S | | main.rs:2113:19:2113:22 | SelfParam | T | main.rs:2109:10:2109:17 | T | | main.rs:2113:25:2113:29 | other | | main.rs:2107:5:2107:19 | S | @@ -11542,13 +11536,9 @@ inferType | main.rs:2154:13:2158:13 | if value {...} else {...} | | {EXTERNAL LOCATION} | i64 | | main.rs:2154:16:2154:20 | value | | {EXTERNAL LOCATION} | bool | | main.rs:2154:22:2156:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2154:22:2156:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2155:17:2155:17 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2155:17:2155:17 | 1 | | {EXTERNAL LOCATION} | i64 | | main.rs:2156:20:2158:13 | { ... } | | {EXTERNAL LOCATION} | i32 | -| main.rs:2156:20:2158:13 | { ... } | | {EXTERNAL LOCATION} | i64 | | main.rs:2157:17:2157:17 | 0 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2157:17:2157:17 | 0 | | {EXTERNAL LOCATION} | i64 | | main.rs:2164:21:2164:25 | value | | main.rs:2162:19:2162:19 | T | | main.rs:2164:31:2164:31 | x | | main.rs:2162:5:2165:5 | Self [trait MyFrom2] | | main.rs:2169:21:2169:25 | value | | {EXTERNAL LOCATION} | i64 | @@ -11710,9 +11700,7 @@ inferType | main.rs:2265:21:2265:31 | [...] | TArray | {EXTERNAL LOCATION} | u8 | | main.rs:2265:22:2265:24 | 1u8 | | {EXTERNAL LOCATION} | u8 | | main.rs:2265:27:2265:27 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2265:27:2265:27 | 2 | | {EXTERNAL LOCATION} | u8 | | main.rs:2265:30:2265:30 | 3 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2265:30:2265:30 | 3 | | {EXTERNAL LOCATION} | u8 | | main.rs:2266:9:2266:25 | for ... in ... { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2266:13:2266:13 | u | | {EXTERNAL LOCATION} | i32 | | main.rs:2266:13:2266:13 | u | | {EXTERNAL LOCATION} | u8 | @@ -11738,11 +11726,8 @@ inferType | main.rs:2271:31:2271:39 | [...] | TArray | {EXTERNAL LOCATION} | i32 | | main.rs:2271:31:2271:39 | [...] | TArray | {EXTERNAL LOCATION} | u32 | | main.rs:2271:32:2271:32 | 1 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:32:2271:32 | 1 | | {EXTERNAL LOCATION} | u32 | | main.rs:2271:35:2271:35 | 2 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:35:2271:35 | 2 | | {EXTERNAL LOCATION} | u32 | | main.rs:2271:38:2271:38 | 3 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2271:38:2271:38 | 3 | | {EXTERNAL LOCATION} | u32 | | main.rs:2272:9:2272:25 | for ... in ... { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2272:13:2272:13 | u | | {EXTERNAL LOCATION} | u32 | | main.rs:2272:18:2272:22 | vals3 | | {EXTERNAL LOCATION} | [;] | @@ -11883,7 +11868,6 @@ inferType | main.rs:2308:19:2308:25 | 0u8..10 | Idx | {EXTERNAL LOCATION} | i32 | | main.rs:2308:19:2308:25 | 0u8..10 | Idx | {EXTERNAL LOCATION} | u8 | | main.rs:2308:24:2308:25 | 10 | | {EXTERNAL LOCATION} | i32 | -| main.rs:2308:24:2308:25 | 10 | | {EXTERNAL LOCATION} | u8 | | main.rs:2308:28:2308:29 | { ... } | | {EXTERNAL LOCATION} | () | | main.rs:2309:13:2309:17 | range | | {EXTERNAL LOCATION} | Range | | main.rs:2309:13:2309:17 | range | Idx | {EXTERNAL LOCATION} | i32 | @@ -12636,11 +12620,9 @@ inferType | main.rs:2583:12:2583:12 | b | | {EXTERNAL LOCATION} | bool | | main.rs:2583:14:2586:9 | { ... } | | {EXTERNAL LOCATION} | Box | | main.rs:2583:14:2586:9 | { ... } | A | {EXTERNAL LOCATION} | Global | -| main.rs:2583:14:2586:9 | { ... } | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2583:14:2586:9 | { ... } | T | main.rs:2551:5:2552:19 | S | | main.rs:2583:14:2586:9 | { ... } | T.T | main.rs:2551:5:2552:19 | S | | main.rs:2583:14:2586:9 | { ... } | T.T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2583:14:2586:9 | { ... } | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2584:17:2584:17 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2584:17:2584:17 | x | T | main.rs:2551:5:2552:19 | S | | main.rs:2584:17:2584:17 | x | T.T | {EXTERNAL LOCATION} | i32 | @@ -12651,26 +12633,20 @@ inferType | main.rs:2584:21:2584:26 | x.m2() | T.T | {EXTERNAL LOCATION} | i32 | | main.rs:2585:13:2585:23 | ...::new(...) | | {EXTERNAL LOCATION} | Box | | main.rs:2585:13:2585:23 | ...::new(...) | A | {EXTERNAL LOCATION} | Global | -| main.rs:2585:13:2585:23 | ...::new(...) | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2585:13:2585:23 | ...::new(...) | T | main.rs:2551:5:2552:19 | S | | main.rs:2585:13:2585:23 | ...::new(...) | T.T | main.rs:2551:5:2552:19 | S | | main.rs:2585:13:2585:23 | ...::new(...) | T.T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2585:13:2585:23 | ...::new(...) | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2585:22:2585:22 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2585:22:2585:22 | x | T | main.rs:2551:5:2552:19 | S | | main.rs:2585:22:2585:22 | x | T.T | {EXTERNAL LOCATION} | i32 | | main.rs:2586:16:2588:9 | { ... } | | {EXTERNAL LOCATION} | Box | | main.rs:2586:16:2588:9 | { ... } | A | {EXTERNAL LOCATION} | Global | -| main.rs:2586:16:2588:9 | { ... } | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2586:16:2588:9 | { ... } | T | main.rs:2551:5:2552:19 | S | | main.rs:2586:16:2588:9 | { ... } | T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2586:16:2588:9 | { ... } | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2587:13:2587:23 | ...::new(...) | | {EXTERNAL LOCATION} | Box | | main.rs:2587:13:2587:23 | ...::new(...) | A | {EXTERNAL LOCATION} | Global | -| main.rs:2587:13:2587:23 | ...::new(...) | T | main.rs:2547:5:2549:5 | dyn MyTrait | | main.rs:2587:13:2587:23 | ...::new(...) | T | main.rs:2551:5:2552:19 | S | | main.rs:2587:13:2587:23 | ...::new(...) | T.T | {EXTERNAL LOCATION} | i32 | -| main.rs:2587:13:2587:23 | ...::new(...) | T.dyn(T) | {EXTERNAL LOCATION} | i32 | | main.rs:2587:22:2587:22 | x | | main.rs:2551:5:2552:19 | S | | main.rs:2587:22:2587:22 | x | T | {EXTERNAL LOCATION} | i32 | | main.rs:2593:22:2597:5 | { ... } | | {EXTERNAL LOCATION} | () | diff --git a/rust/ql/test/library-tests/type-inference/type-inference.ql b/rust/ql/test/library-tests/type-inference/type-inference.ql index 8dcc34ad8001..374884ec4574 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.ql +++ b/rust/ql/test/library-tests/type-inference/type-inference.ql @@ -12,7 +12,7 @@ private predicate relevantNode(AstNode n) { } query predicate inferCertainType(AstNode n, TypePath path, Type t) { - t = TypeInference::CertainTypeInference::inferCertainType(n, path) and + t = TypeInference::inferTypeCertain(n, path) and t != TUnknownType() and relevantNode(n) } @@ -70,7 +70,7 @@ module TypeTest implements TestSig { ( tag = "type" or - t = TypeInference::CertainTypeInference::inferCertainType(n, path) and + t = TypeInference::inferTypeCertain(n, path) and tag = "certainType" ) and location = n.getLocation() and diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index cf82d77b5e1d..0f646ab3552c 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -146,17 +146,25 @@ signature module InputSig1 { } /** - * Holds if `t` is a pseudo type. Pseudo types are skipped when checking for - * non-instantiations in `isNotInstantiationOf`. + * A special pseudo type used to represent cases where the actual type needs + * to be inferred from the context in a top-down manner. For example, in + * + * ```rust + * let x = Vec::new(); + * x.push(42); + * ``` + * + * the element type of `x` is assigned an unknown type, which allows for type + * information to flow into `x` from the call to `push`. */ - predicate isPseudoType(Type t); + class UnknownType extends Type; /** A type parameter. */ class TypeParameter extends Type; /** - * A type abstraction. I.e., a place in the program where type variables are - * introduced. + * A type abstraction. I.e., a place in the program where type variables may + * be introduced. * * Example in C#: * ```csharp @@ -171,7 +179,7 @@ signature module InputSig1 { * ``` */ class TypeAbstraction { - /** Gets a type parameter introduced by this abstraction. */ + /** Gets a type parameter introduced by this abstraction, if any. */ TypeParameter getATypeParameter(); /** Gets a textual representation of this type abstraction. */ @@ -332,6 +340,8 @@ module Make1 Input1> { * code. For example, in * * ```csharp + * class Base { } + * * class C : Base, Interface { } * ``` * @@ -341,7 +351,7 @@ module Make1 Input1> { * `TypePath` | `Type` * ---------- | ------- * `""` | ``Base`1`` - * `"0"` | `T` + * `"B"` | `T` */ signature module InputSig2 { /** @@ -666,7 +676,8 @@ module Make1 Input1> { } private Type getNonPseudoTypeAt(App app, TypePath path) { - result = app.getTypeAt(path) and not isPseudoType(result) + result = app.getTypeAt(path) and + not result instanceof UnknownType } pragma[nomagic] @@ -2127,5 +2138,698 @@ module Make1 Input1> { not exists(tm.getTypeAt(TypePath::nil())) and exists(tm.getLocation()) } } + + /** + * Provides the input to `Make3`. + * + * TODO: Eventually align the AST signature with that of the shared CFG library. + */ + signature module InputSig3 { + /** + * A predicate used to reference cached predicates that should be included to the + * cached stage of type inference. Such predicates should themselves reference + * `CachedStage::ref`. + */ + default predicate cachedStageRevRef() { none() } + + /** + * Point this predicate to the `inferType` predicate from the output of this module. + * + * Needed to be able to refer to `inferType` in default signature implementations. + */ + Type inferType(AstNode n, TypePath path); + + /** A boolean type. */ + class BoolType extends Type; + + /** An AST node. */ + class AstNode { + /** Gets a textual representation of this AST node. */ + string toString(); + + /** Gets the location of this AST node. */ + Location getLocation(); + } + + /** Gets the type annotation that applies to `n`, if any. */ + TypeMention getTypeAnnotation(AstNode n); + + /** An expression. */ + class Expr extends AstNode; + + /** + * A switch. + */ + class Switch extends AstNode { + /** + * Gets the expression being switched on. + */ + Expr getExpr(); + + /** Gets the case at the specified (zero-based) `index`. */ + Case getCase(int index); + } + + /** A case in a switch. */ + class Case extends AstNode { + /** Gets a pattern being matched by this case. */ + AstNode getAPattern(); + + /** Gets the body of this case. */ + AstNode getBody(); + } + + /** A ternary conditional expression. */ + class ConditionalExpr extends Expr { + /** Gets the condition of this expression. */ + Expr getCondition(); + + /** Gets the true branch of this expression. */ + Expr getThen(); + + /** Gets the false branch of this expression. */ + Expr getElse(); + } + + /** A binary expression. */ + class BinaryExpr extends Expr { + /** Gets the left operand of this binary expression. */ + Expr getLeftOperand(); + + /** Gets the right operand of this binary expression. */ + Expr getRightOperand(); + } + + /** A short-circuiting logical AND expression. */ + class LogicalAndExpr extends BinaryExpr; + + /** A short-circuiting logical OR expression. */ + class LogicalOrExpr extends BinaryExpr; + + /** + * An assignment expression, either compound or simple. + * + * Examples: + * + * ``` + * x = y + * sum += element + * ``` + */ + class Assignment extends BinaryExpr; + + /** A simple assignment expression, for example `x = y`. */ + class AssignExpr extends Assignment; + + /** A parenthesized expression. */ + class ParenExpr extends Expr { + Expr getExpr(); + } + + /** A variable, for example a local variable or a field. */ + class Variable { + /** Gets the AST node that defines this variable. */ + AstNode getDefiningNode(); + + /** Gets an access to this variable. */ + Expr getAnAccess(); + + /** Gets a textual representation of this element. */ + string toString(); + + /** Gets the location of this element. */ + Location getLocation(); + } + + /** + * A `let` declaration, for example a local variable declaration. + */ + class LetDeclaration extends AstNode { + /** + * Holds if this declaration is a coercion site, meaning that the type of the right + * operand may have to be coerced to the type of the left operand. + */ + predicate isCoercionSite(); + + /** Gets the left operand of this declaration. */ + AstNode getLeftOperand(); + + /** Gets the right operand of this declaration. */ + AstNode getRightOperand(); + } + + /** + * A position where a callable can have a declared type and a call can have + * an inferred type. + */ + class TypePosition { + /** Holds if this position represents the return type of a callable. */ + predicate isReturn(); + + /** Gets a textual representation of this position. */ + string toString(); + } + + /** + * A context needed to resolve calls. + * + * For example, in Rust, we need an additional context to represent the + * candidate receiver type when resolving method calls. + * + * When not used, simply instantiate this class with `Unit`. + */ + bindingset[this] + class CallResolutionContext { + /** Gets a textual representation of this context. */ + bindingset[this] + string toString(); + } + + /** A callable. */ + class Callable { + /** Gets the type parameter at position `ppos` of this callable, if any. */ + TypeParameter getTypeParameter(TypeParameterPosition ppos); + + /** + * Gets an additional type parameter constraint for the given type parameter, + * which applies to this callable. For example, in Rust, a function can apply + * additional constraints on type parameters belonging to the `impl` block + * that the function is defined in. + */ + TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp); + + /** Gets the declared type of this callable at `path` for position `pos`. */ + Type getDeclaredType(TypePosition pos, TypePath path); + + /** Gets a textual representation of this callable. */ + string toString(); + + /** Gets the location of this callable. */ + Location getLocation(); + } + + /** A call expression. */ + class Call extends Expr { + /** Gets the explicit type argument at position `apos` and `path` for this call, if any. */ + Type getTypeArgument(TypeArgumentPosition apos, TypePath path); + + /** Gets the AST node corresponding to the position `pos` of this call. */ + AstNode getNodeAt(TypePosition pos); + + /** + * Gets the target of this call, to be used when inferring certain types. + */ + Callable getTargetCertain(); + + /** Gets the target of this call in the given context. */ + Callable getTarget(CallResolutionContext ctx); + } + + /** + * Gets the inferred type of `call` at `path` and position `pos` in context `ctx`. + * + * By default, this is the inferred type of the node at the given position, but + * in for example Rust, the inferred type of the receiver of a method call needs + * to take the call context into account, in order to use the correct candidate + * receiver type. + * + * The type information provided by this predicate is used to derive type information + * about the call via the call target, such as the return type. + */ + bindingset[ctx] + default Type inferCallTypeBottomUp( + Call call, CallResolutionContext ctx, TypePosition pos, TypePath path + ) { + result = inferType(call.getNodeAt(pos), path) and + exists(ctx) + } + + /** + * Gets the inferred return type of `call` at `path`. + * + * When no post-processing is needed, simply implement this predicate as + * `result = inferCallReturnType(_, _, n, path)`. + */ + Type inferCallReturnType(AstNode n, TypePath path); + + /** + * Gets the top-down inferred type of `call` at `path` and argument position + * `pos`. + * + * This predicate is used to propagate type information from the call target + * into call arguments, for example when an implicitly typed lambda is passed + * as an argument. + * + * Type information is only propagated into arguments with an explicitly unknown + * type. + * + * When no call-context based post-processing is needed, simply implement this + * predicate as `result = inferCallArgumentTypeTopDown(_, _, _, n, path)`. + */ + Type inferCallArgumentTypeTopDown(AstNode n, TypePath path); + + /** + * Holds if `n1` having certain type `t` at `path1` implies that `n2` has + * certain type `t` at `path2`, but not necessarily the other way around. + */ + default predicate inferStepCertain(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + none() + } + + /** + * Holds if `n1` having certain type `t` at `path1` implies that `n2` has + * certain type `t` at `path2`, and vice versa. + */ + default predicate inferStepSymmetricCertain( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { + none() + } + + /** + * Gets the inferred certain type of `n` at `path`. + * + * This predicate will be included directly in the exposed `inferTypeCertain` predicate. + */ + default Type inferTypeCertainSpecific(AstNode n, TypePath path) { none() } + + /** + * Holds if `n1` having type `t` at `path1` implies that `n2` has type `t` at `path2`, + * but not necessarily the other way around. + */ + predicate inferStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** + * Holds if `n1` having type `t` at `path1` implies that `n2` has type `t` at `path2`, + * and vice versa. + */ + predicate inferStepSymmetric(AstNode n1, TypePath path1, AstNode n2, TypePath path2); + + /** + * Holds if `n1` having type `t` at `path1` implies that `n2` has a type `lub` at + * `path2`, where `lub` is a least-upper-bound of the types of all the nodes that + * have lub steps into `n2`. + * + * For example, for a ternary conditional expression, there are lub steps from each + * of the branches into the conditional expression itself. + * + * We don't actually model the least-upper-bound computation, instead we interpret + * `inferLubStep(n1, path1, n2, path2)` as + * + * - `inferStep(n1, path1, n2, path2)`, that is type information flows directly into + * the lub, and + * - `inferStep(n2, path2, n1, path1)`, provided that `n1` is unique, that is, type + * type information flows from the lub back into the unique input `n1`, and + * - type information is allowed to flow from the lub into any of its inputs, provided + * that they have an explicitly unknown type. + */ + default predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + none() + } + + /** + * Gets the top-down inferred type of `n` at `path`. + * + * Type information is only propagated into nodes with an explicitly unknown + * type. + */ + default Type inferTypeTopDown(AstNode n, TypePath path) { none() } + + /** + * Gets the inferred type of `n` at `path`. + * + * This predicate will be included directly in the exposed `inferType` predicate. + */ + Type inferTypeSpecific(AstNode n, TypePath path); + } + + module Make3 { + private import Input3 + + /** Provides logic for inferring certain type information. */ + private module Certain { + /** Gets the type of `n`, which has an explicit type annotation. */ + pragma[nomagic] + Type inferAnnotatedType(AstNode n, TypePath path) { + result = getTypeAnnotation(n).getTypeAt(path) + } + + private predicate stepSymmetricCertain( + AstNode n1, TypePath path1, AstNode n2, TypePath path2 + ) { + path1.isEmpty() and + path2.isEmpty() and + ( + exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode()) + or + exists(LetDeclaration let | + not let.isCoercionSite() and + n1 = let.getLeftOperand() and + n2 = let.getRightOperand() + ) + or + n1 = n2.(ParenExpr).getExpr() + ) + or + inferStepSymmetricCertain(n1, path1, n2, path2) + } + + predicate stepCertain(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + stepSymmetricCertain(n1, path1, n2, path2) + or + stepSymmetricCertain(n2, path2, n1, path1) + or + inferStepCertain(n1, path1, n2, path2) + } + + pragma[nomagic] + private Type inferTypeFromStepCertain(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferTypeCertain(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) and + stepCertain(n2, prefix2, n, prefix1) + ) + } + + private Type inferLogicalOperationType(AstNode n, TypePath path) { + ( + exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or + exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) + ) and + result instanceof BoolType and + path.isEmpty() + } + + pragma[nomagic] + private Type getCertainCallExprReturnType(Call call, TypePath path) { + exists(TypePosition ret | + ret.isReturn() and + forex(Callable target | target = call.getTargetCertain() | + result = target.getDeclaredType(ret, path) + ) + ) + } + + pragma[nomagic] + private Type inferCertainCallExprReturnType(Call call, TypePath path) { + exists(Type ty, TypePath prefix | ty = getCertainCallExprReturnType(call, prefix) | + exists( + Callable target, TypePath suffix, TypeParameterPosition tppos, + TypeArgumentPosition tapos + | + ty = target.getTypeParameter(tppos) and + path = prefix.append(suffix) and + result = call.getTypeArgument(tapos, suffix) and + typeArgumentParameterPositionMatch(tapos, tppos) + ) + or + not ty instanceof TypeParameter and + result = ty and + path = prefix + ) + } + + /** Gets the inferred certain type of `n` at `path`. */ + cached + Type inferTypeCertain(AstNode n, TypePath path) { + CachedStage::ref() and + result = inferAnnotatedType(n, path) + or + result = inferTypeFromStepCertain(n, path) + or + result = inferTypeCertainSpecific(n, path) + or + result = inferLogicalOperationType(n, path) + or + result = inferCertainCallExprReturnType(n, path) + or + infersCertainTypeAt(n, path, result.getATypeParameter()) + } + + /** + * Holds if `n` has complete and certain type information at the type path + * `prefix.tp`. This entails that the type at `prefix` must be the type + * that declares `tp`. + */ + pragma[nomagic] + private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) { + exists(TypePath path | + exists(inferTypeCertain(n, path)) and + path.isSnoc(prefix, tp) + ) + } + + /** + * Holds if `n` has complete and certain type information at `path`. + */ + pragma[nomagic] + predicate hasInferredCertainType(AstNode n, TypePath path) { + exists(inferTypeCertain(n, path)) + } + + /** + * Holds if `n` having type `t` at `path` conflicts with certain type information + * at `prefix`. + */ + bindingset[n, prefix, path, t] + pragma[inline_late] + predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { + inferTypeCertain(n, path) != t + or + // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also + // know that `n` certainly has type `certainType` at `T1.T2...Ti`, `0 <= i < n`, + // then it must be the case that `T(i+1)` is a type parameter of `certainType`, + // otherwise there is a conflict. + // + // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. + exists(TypePath suffix, TypeParameter tp, Type certainType | + path = prefix.appendInverse(suffix) and + tp = suffix.getHead() and + inferTypeCertain(n, prefix) = certainType and + not certainType.getATypeParameter() = tp + ) + } + } + + predicate inferTypeCertain = Certain::inferTypeCertain/2; + + private predicate lubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + path2.isEmpty() and + ( + n1 = n2.(Switch).getCase(_).getBody() + or + n2 = any(ConditionalExpr ce | n1 = [ce.getThen(), ce.getElse()]) + ) + or + inferLubStep(n1, path1, n2, path2) + } + + private predicate stepSymmetric(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + path1.isEmpty() and + path2.isEmpty() and + ( + exists(AssignExpr ae | + ae.getLeftOperand() = n1 and + ae.getRightOperand() = n2 + ) + or + exists(LetDeclaration let | + let.getLeftOperand() = n1 and + let.getRightOperand() = n2 + ) + or + exists(Switch switch | + n1 = switch.getExpr() and + n2 = switch.getCase(_).getAPattern() + ) + ) + or + inferStepSymmetric(n1, path1, n2, path2) + } + + private predicate step(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + inferStep(n1, path1, n2, path2) + or + stepSymmetric(n1, path1, n2, path2) + or + stepSymmetric(n2, path2, n1, path1) + or + Certain::stepCertain(n1, path1, n2, path2) + or + lubStep(n1, path1, n2, path2) + or + n2 = unique(AstNode n | lubStep(n, _, n1, _) | n) and + lubStep(n2, path2, n1, path1) + } + + pragma[nomagic] + private Type inferTypeFromStep(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) and + step(n2, prefix2, n, prefix1) + ) + } + + pragma[nomagic] + private Type inferTypeFromLubStepTopDown(AstNode n, TypePath path) { + exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix | + result = inferType(n2, prefix2.appendInverse(suffix)) and + path = prefix1.append(suffix) and + lubStep(n, prefix1, n2, prefix2) + ) + } + + /** + * Gets the inferred type of `n` at `path`. + */ + cached + Type inferType(AstNode n, TypePath path) { + CachedStage::ref() and + result = inferTypeCertain(n, path) + or + // Don't propagate type information into a node which conflicts with certain + // type information. + forall(TypePath prefix | + Certain::hasInferredCertainType(n, prefix) and + prefix.isPrefixOf(path) + | + not Certain::certainTypeConflict(n, prefix, path, result) + ) and + ( + result = inferTypeFromStep(n, path) + or + result = TopDownTyping::inferType(n, path) + or + result = inferCallReturnType(n, path) + or + result = TopDownTyping::inferType(n, path) + or + result = TopDownTyping::inferType(n, path) + or + result = inferTypeSpecific(n, path) + ) + } + + private module TypePositionMatchingInput { + class DeclarationPosition = TypePosition; + + class AccessPosition = DeclarationPosition; + + predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) { + apos = dpos + } + } + + /** + * A matching configuration for resolving types of calls. + */ + private module CallMatchingInput implements MatchingWithEnvironmentInputSig { + import TypePositionMatchingInput + + class Declaration = Callable; + + bindingset[decl] + TypeMention getATypeParameterConstraint(TypeParameter tp, Declaration decl) { + result = Input2::getATypeParameterConstraint(tp) and + exists(decl) + or + result = decl.getAdditionalTypeParameterConstraint(tp) + } + + class AccessEnvironment = CallResolutionContext; + + final private class CallFinal = Call; + + class Access extends CallFinal { + bindingset[e] + Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) { + result = inferCallTypeBottomUp(this, e, apos, path) + } + } + } + + private module CallMatching = MatchingWithEnvironment; + + private Type inferCallType( + Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path + ) { + n = call.getNodeAt(pos) and + result = CallMatching::inferAccessType(call, ctx, pos, path) + } + + Type inferCallReturnType(Call call, CallResolutionContext ctx, AstNode n, TypePath path) { + exists(TypePosition pos | + result = inferCallType(call, ctx, pos, n, path) and + pos.isReturn() + ) + } + + Type inferCallArgumentTypeTopDown( + Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path + ) { + result = inferCallType(call, ctx, pos, n, path) and + not pos.isReturn() and + hasUnknownType(n) + } + + pragma[nomagic] + private predicate hasUnknownTypeAt(AstNode n, TypePath path) { + inferType(n, path) instanceof UnknownType + } + + pragma[nomagic] + private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } + + private signature Type inferTypeTopDownSig(AstNode n, TypePath path); + + /** + * Given a predicate `infer` for inferring the type of an AST node `n` + * top-down from a context, this module exposes the predicate `inferType`, which + * restricts type information to only flow top-down into `n` when `n` has an + * explicit unknown type. + */ + private module TopDownTyping { + pragma[nomagic] + private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) { + result = infer(n, path) and + hasUnknownType(n) and + prefix = path.getAPrefix() + } + + pragma[nomagic] + Type inferType(AstNode n, TypePath path) { + exists(TypePath prefix | + result = inferTypeTopDown(n, prefix, path) and + hasUnknownTypeAt(n, prefix) + ) + } + } + + /** + * Gets the inferred root type of `n`, if any. + */ + Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) } + + // todo: consistency checks + /** The cached stage of type inference. */ + cached + module CachedStage { + /** Reference to the cached stage of type inference. */ + cached + predicate ref() { any() } + + /** Reverse references to the predicates that reference `ref()`. */ + cached + predicate revRef() { + (exists(inferTypeCertain(_, _)) implies any()) + or + (exists(inferType(_, _)) implies any()) + or + cachedStageRevRef() + } + } + } } }