Skip to content

Commit 05e9261

Browse files
committed
Rust: Use Call in type inference
1 parent ea2005c commit 05e9261

File tree

1 file changed

+35
-180
lines changed

1 file changed

+35
-180
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 35 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ private import TypeMention
88
private import codeql.typeinference.internal.TypeInference
99
private import codeql.rust.frameworks.stdlib.Stdlib
1010
private import codeql.rust.frameworks.stdlib.Bultins as Builtins
11+
private import codeql.rust.elements.Call
1112

1213
class Type = T::Type;
1314

@@ -488,28 +489,25 @@ private Type inferPathExprType(PathExpr pe, TypePath path) {
488489
* like `foo::bar(baz)` and `foo.bar(baz)`.
489490
*/
490491
private module CallExprBaseMatchingInput implements MatchingInputSig {
491-
private predicate paramPos(ParamList pl, Param p, int pos, boolean inMethod) {
492-
p = pl.getParam(pos) and
493-
if pl.hasSelfParam() then inMethod = true else inMethod = false
494-
}
492+
private predicate paramPos(ParamList pl, Param p, int pos) { p = pl.getParam(pos) }
495493

496494
private newtype TDeclarationPosition =
497495
TSelfDeclarationPosition() or
498-
TPositionalDeclarationPosition(int pos, boolean inMethod) { paramPos(_, _, pos, inMethod) } or
496+
TPositionalDeclarationPosition(int pos) { paramPos(_, _, pos) } or
499497
TReturnDeclarationPosition()
500498

501499
class DeclarationPosition extends TDeclarationPosition {
502500
predicate isSelf() { this = TSelfDeclarationPosition() }
503501

504-
int asPosition(boolean inMethod) { this = TPositionalDeclarationPosition(result, inMethod) }
502+
int asPosition() { this = TPositionalDeclarationPosition(result) }
505503

506504
predicate isReturn() { this = TReturnDeclarationPosition() }
507505

508506
string toString() {
509507
this.isSelf() and
510508
result = "self"
511509
or
512-
result = this.asPosition(_).toString()
510+
result = this.asPosition().toString()
513511
or
514512
this.isReturn() and
515513
result = "(return)"
@@ -542,7 +540,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
542540
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
543541
exists(int pos |
544542
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
545-
dpos = TPositionalDeclarationPosition(pos, false)
543+
dpos = TPositionalDeclarationPosition(pos)
546544
)
547545
}
548546

@@ -565,7 +563,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
565563
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
566564
exists(int p |
567565
result = this.getTupleField(p).getTypeRepr().(TypeMention).resolveTypeAt(path) and
568-
dpos = TPositionalDeclarationPosition(p, false)
566+
dpos = TPositionalDeclarationPosition(p)
569567
)
570568
}
571569

@@ -598,9 +596,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
598596
}
599597

600598
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
601-
exists(Param p, int i, boolean inMethod |
602-
paramPos(this.getParamList(), p, i, inMethod) and
603-
dpos = TPositionalDeclarationPosition(i, inMethod) and
599+
exists(Param p, int i |
600+
paramPos(this.getParamList(), p, i) and
601+
dpos = TPositionalDeclarationPosition(i) and
604602
result = inferAnnotatedType(p.getPat(), path)
605603
)
606604
or
@@ -632,125 +630,44 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
632630
}
633631
}
634632

635-
private predicate argPos(CallExprBase call, Expr e, int pos, boolean isMethodCall) {
636-
exists(ArgList al |
637-
e = al.getArg(pos) and
638-
call.getArgList() = al and
639-
if call instanceof MethodCallExpr then isMethodCall = true else isMethodCall = false
640-
)
641-
}
642-
643-
private newtype TAccessPosition =
644-
TSelfAccessPosition() or
645-
TPositionalAccessPosition(int pos, boolean isMethodCall) { argPos(_, _, pos, isMethodCall) } or
646-
TReturnAccessPosition()
647-
648-
class AccessPosition extends TAccessPosition {
649-
predicate isSelf() { this = TSelfAccessPosition() }
650-
651-
int asPosition(boolean isMethodCall) { this = TPositionalAccessPosition(result, isMethodCall) }
652-
653-
predicate isReturn() { this = TReturnAccessPosition() }
654-
655-
string toString() {
656-
this.isSelf() and
657-
result = "self"
658-
or
659-
result = this.asPosition(_).toString()
660-
or
661-
this.isReturn() and
662-
result = "(return)"
663-
}
664-
}
633+
class AccessPosition = DeclarationPosition;
665634

666635
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
667636

668-
abstract class Access extends Expr {
669-
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
670-
671-
abstract AstNode getNodeAt(AccessPosition apos);
672-
673-
abstract Type getInferredType(AccessPosition apos, TypePath path);
674-
675-
abstract Declaration getTarget();
676-
}
677-
678-
private class CallExprBaseAccess extends Access instanceof CallExprBase {
679-
private TypeMention getMethodTypeArg(int i) {
680-
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
681-
}
682-
683-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
637+
final class Access extends Call {
638+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
684639
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
685640
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
686641
or
687-
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition())
642+
arg =
643+
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
688644
)
689645
}
690646

691-
override AstNode getNodeAt(AccessPosition apos) {
692-
exists(int p, boolean isMethodCall |
693-
argPos(this, result, p, isMethodCall) and
694-
apos = TPositionalAccessPosition(p, isMethodCall)
695-
)
647+
AstNode getNodeAt(AccessPosition apos) {
648+
result = this.getArgument(apos.asPosition())
696649
or
697-
result = this.(MethodCallExpr).getReceiver() and
698-
apos = TSelfAccessPosition()
650+
result = this.getReceiver() and apos.isSelf()
699651
or
700-
result = this and
701-
apos = TReturnAccessPosition()
652+
result = this and apos.isReturn()
702653
}
703654

704-
override Type getInferredType(AccessPosition apos, TypePath path) {
655+
Type getInferredType(AccessPosition apos, TypePath path) {
705656
result = inferType(this.getNodeAt(apos), path)
706657
}
707658

708-
override Declaration getTarget() {
709-
result = CallExprImpl::getResolvedFunction(this)
710-
or
659+
Declaration getTarget() {
711660
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
712-
}
713-
}
714-
715-
private class OperationAccess extends Access instanceof Operation {
716-
OperationAccess() { super.isOverloaded(_, _) }
717-
718-
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
719-
// The syntax for operators does not allow type arguments.
720-
none()
721-
}
722-
723-
override AstNode getNodeAt(AccessPosition apos) {
724-
result = super.getOperand(0) and apos = TSelfAccessPosition()
725-
or
726-
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
727661
or
728-
result = this and apos = TReturnAccessPosition()
729-
}
730-
731-
override Type getInferredType(AccessPosition apos, TypePath path) {
732-
result = inferType(this.getNodeAt(apos), path)
733-
}
734-
735-
override Declaration getTarget() {
736-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
662+
result = CallExprImpl::getResolvedFunction(this)
737663
}
738664
}
739665

740666
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
741667
apos.isSelf() and
742668
dpos.isSelf()
743669
or
744-
exists(int pos, boolean isMethodCall | pos = apos.asPosition(isMethodCall) |
745-
pos = 0 and
746-
isMethodCall = false and
747-
dpos.isSelf()
748-
or
749-
isMethodCall = false and
750-
pos = dpos.asPosition(true) + 1
751-
or
752-
pos = dpos.asPosition(isMethodCall)
753-
)
670+
apos.asPosition() = dpos.asPosition()
754671
or
755672
apos.isReturn() and
756673
dpos.isReturn()
@@ -1128,91 +1045,29 @@ private Type inferAwaitExprType(AstNode n, TypePath path) {
11281045
)
11291046
}
11301047

1131-
private module MethodCall {
1132-
/** An expression that calls a method. */
1133-
abstract private class MethodCallImpl extends Expr {
1134-
/** Gets the name of the method targeted. */
1135-
abstract string getMethodName();
1136-
1137-
/** Gets the number of arguments _excluding_ the `self` argument. */
1138-
abstract int getArity();
1139-
1140-
/** Gets the trait targeted by this method call, if any. */
1141-
Trait getTrait() { none() }
1142-
1143-
/** Gets the type of the receiver of the method call at `path`. */
1144-
abstract Type getTypeAt(TypePath path);
1048+
final class MethodCall extends Call {
1049+
MethodCall() {
1050+
exists(this.getReceiver()) and
1051+
// We want the method calls that don't have a path to a concrete method in
1052+
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
1053+
(this instanceof CallExpr implies exists(this.getTrait()))
11451054
}
11461055

1147-
final class MethodCall = MethodCallImpl;
1148-
1149-
private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
1150-
override string getMethodName() { result = super.getIdentifier().getText() }
1151-
1152-
override int getArity() { result = super.getArgList().getNumberOfArgs() }
1153-
1154-
pragma[nomagic]
1155-
override Type getTypeAt(TypePath path) {
1056+
/** Gets the type of the receiver of the method call at `path`. */
1057+
Type getTypeAt(TypePath path) {
1058+
if this.receiverImplicitlyBorrowed()
1059+
then
11561060
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
11571061
path0.isCons(TRefTypeParameter(), path)
11581062
or
11591063
not path0.isCons(TRefTypeParameter(), _) and
11601064
not (path0.isEmpty() and result = TRefType()) and
11611065
path = path0
11621066
)
1163-
}
1164-
}
1165-
1166-
private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
1167-
TraitItemNode trait;
1168-
string methodName;
1169-
Expr receiver;
1170-
1171-
CallExprMethodCall() {
1172-
receiver = this.getArg(0) and
1173-
exists(Path path, Function f |
1174-
path = this.getFunction().(PathExpr).getPath() and
1175-
f = resolvePath(path) and
1176-
f.getParamList().hasSelfParam() and
1177-
trait = resolvePath(path.getQualifier()) and
1178-
trait.getAnAssocItem() = f and
1179-
path.getSegment().getIdentifier().getText() = methodName
1180-
)
1181-
}
1182-
1183-
override string getMethodName() { result = methodName }
1184-
1185-
override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }
1186-
1187-
override Trait getTrait() { result = trait }
1188-
1189-
pragma[nomagic]
1190-
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
1191-
}
1192-
1193-
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
1194-
TraitItemNode trait;
1195-
string methodName;
1196-
1197-
OperationMethodCall() { super.isOverloaded(trait, methodName) }
1198-
1199-
override string getMethodName() { result = methodName }
1200-
1201-
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }
1202-
1203-
override Trait getTrait() { result = trait }
1204-
1205-
pragma[nomagic]
1206-
override Type getTypeAt(TypePath path) {
1207-
result = inferType(this.(BinaryExpr).getLhs(), path)
1208-
or
1209-
result = inferType(this.(PrefixExpr).getExpr(), path)
1210-
}
1067+
else result = inferType(super.getReceiver(), path)
12111068
}
12121069
}
12131070

1214-
import MethodCall
1215-
12161071
/**
12171072
* Holds if a method for `type` with the name `name` and the arity `arity`
12181073
* exists in `impl`.
@@ -1241,7 +1096,7 @@ private module IsInstantiationOfInput implements IsInstantiationOfInputSig<Metho
12411096
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
12421097
rootType = mc.getTypeAt(TypePath::nil()) and
12431098
name = mc.getMethodName() and
1244-
arity = mc.getArity()
1099+
arity = mc.getNumberOfArguments()
12451100
}
12461101

12471102
pragma[nomagic]

0 commit comments

Comments
 (0)