Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class StreamCipherInit extends Cryptography::CryptographicOperation::Range {
// extract the algorithm name from the type of `ce` or its receiver.
exists(Type t, TypePath tp |
t = inferType([ce, ce.(MethodCallExpr).getReceiver()], tp) and
rawAlgorithmName =
t.(StructType).asItemNode().(Addressable).getCanonicalPath().splitAt("::")
rawAlgorithmName = t.(StructType).getStruct().(Addressable).getCanonicalPath().splitAt("::")
) and
algorithmName = simplifyAlgorithmName(rawAlgorithmName) and
// only match a known cryptographic algorithm
Expand Down
82 changes: 7 additions & 75 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ private predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypePar
* types, such as traits and implementation blocks.
*/
abstract class Type extends TType {
/** Gets the struct field `name` belonging to this type, if any. */
pragma[nomagic]
abstract StructField getStructField(string name);

/** Gets the `i`th tuple field belonging to this type, if any. */
pragma[nomagic]
abstract TupleField getTupleField(int i);

/**
* Gets the `i`th positional type parameter of this type, if any.
*
Expand Down Expand Up @@ -117,10 +109,6 @@ class TupleType extends Type, TTuple {

TupleType() { this = TTuple(arity) }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TTupleTypeParameter(arity, i)
}
Expand All @@ -140,21 +128,14 @@ class UnitType extends TupleType {
override string toString() { result = "()" }
}

abstract private class StructOrEnumType extends Type {
abstract ItemNode asItemNode();
}

/** A struct type. */
class StructType extends StructOrEnumType, TStruct {
class StructType extends Type, TStruct {
private Struct struct;

StructType() { this = TStruct(struct) }

override ItemNode asItemNode() { result = struct }

override StructField getStructField(string name) { result = struct.getStructField(name) }

override TupleField getTupleField(int i) { result = struct.getTupleField(i) }
/** Gets the struct that this struct type represents. */
Struct getStruct() { result = struct }

override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i))
Expand All @@ -170,17 +151,11 @@ class StructType extends StructOrEnumType, TStruct {
}

/** An enum type. */
class EnumType extends StructOrEnumType, TEnum {
class EnumType extends Type, TEnum {
private Enum enum;

EnumType() { this = TEnum(enum) }

override ItemNode asItemNode() { result = enum }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i))
}
Expand All @@ -203,10 +178,6 @@ class TraitType extends Type, TTrait {
/** Gets the underlying trait. */
Trait getTrait() { result = trait }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
}
Expand All @@ -229,16 +200,13 @@ class TraitType extends Type, TTrait {
}

/** A union type. */
class UnionType extends StructOrEnumType, TUnion {
class UnionType extends Type, TUnion {
private Union union;

UnionType() { this = TUnion(union) }

override ItemNode asItemNode() { result = union }

override StructField getStructField(string name) { result = union.getStructField(name) }

override TupleField getTupleField(int i) { none() }
/** Gets the union that this union type represents. */
Union getUnion() { result = union }

override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(union.getGenericParamList().getTypeParam(i))
Expand All @@ -262,10 +230,6 @@ class UnionType extends StructOrEnumType, TUnion {
class ArrayType extends Type, TArrayType {
ArrayType() { this = TArrayType() }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TArrayTypeParameter() and
i = 0
Expand All @@ -285,10 +249,6 @@ class ArrayType extends Type, TArrayType {
class RefType extends Type, TRefType {
RefType() { this = TRefType() }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TRefTypeParameter() and
i = 0
Expand Down Expand Up @@ -318,10 +278,6 @@ class ImplTraitType extends Type, TImplTraitType {
/** Gets the function that this `impl Trait` belongs to. */
abstract Function getFunction();

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
exists(TypeParam tp |
implTraitTypeParam(impl, i, tp) and
Expand All @@ -339,10 +295,6 @@ class DynTraitType extends Type, TDynTraitType {

DynTraitType() { this = TDynTraitType(trait) }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override DynTraitTypeParameter getPositionalTypeParameter(int i) {
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
}
Expand Down Expand Up @@ -389,10 +341,6 @@ class ImplTraitReturnType extends ImplTraitType {
class SliceType extends Type, TSliceType {
SliceType() { this = TSliceType() }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
result = TSliceTypeParameter() and
i = 0
Expand All @@ -404,10 +352,6 @@ class SliceType extends Type, TSliceType {
}

class NeverType extends Type, TNeverType {
override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) { none() }

override string toString() { result = "!" }
Expand All @@ -416,10 +360,6 @@ class NeverType extends Type, TNeverType {
}

class PtrType extends Type, TPtrType {
override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) {
i = 0 and
result = TPtrTypeParameter()
Expand All @@ -432,10 +372,6 @@ class PtrType extends Type, TPtrType {

/** A type parameter. */
abstract class TypeParameter extends Type {
override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) { none() }
}

Expand Down Expand Up @@ -634,10 +570,6 @@ class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {

override Function getFunction() { result = function }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getPositionalTypeParameter(int i) { none() }
}

Expand Down
17 changes: 11 additions & 6 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1173,8 +1173,8 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
path = TypePath::cons(TRefTypeParameter(), path0)
else (
not (
argType.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
argType.(StructType).getStruct() instanceof StringStruct and
result.(StructType).getStruct() instanceof Builtins::Str
) and
(
not path0.isCons(TRefTypeParameter(), _) and
Expand Down Expand Up @@ -1889,8 +1889,8 @@ final class MethodCall extends Call {
//
// See also https://doc.rust-lang.org/reference/expressions/method-call-expr.html#r-expr.method.autoref-deref
path.isEmpty() and
t0.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
t0.(StructType).getStruct() instanceof StringStruct and
result.(StructType).getStruct() instanceof Builtins::Str
)
else result = this.getReceiverTypeAt(path)
}
Expand Down Expand Up @@ -2518,15 +2518,20 @@ private module Cached {
*/
cached
StructField resolveStructFieldExpr(FieldExpr fe) {
exists(string name | result = getFieldExprLookupType(fe, name).getStructField(name))
exists(string name, Type ty | ty = getFieldExprLookupType(fe, name) |
result = ty.(StructType).getStruct().getStructField(name) or
result = ty.(UnionType).getUnion().getStructField(name)
)
}

/**
* Gets the tuple field that the field expression `fe` resolves to, if any.
*/
cached
TupleField resolveTupleFieldExpr(FieldExpr fe) {
exists(int i | result = getTupleFieldExprLookupType(fe, i).getTupleField(i))
exists(int i |
result = getTupleFieldExprLookupType(fe, i).(StructType).getStruct().getTupleField(i)
)
}

/**
Expand Down