diff --git a/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll b/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll index 51d00f795d7e..123824b3d696 100644 --- a/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll +++ b/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll @@ -31,8 +31,7 @@ class StreamCipherInit extends Cryptography::CryptographicOperation::Range { // extract the algorithm name from the type of `ce` or its receiver. exists(Type t, TypePath tp | t = inferType([ce, ce.(MethodCallExpr).getReceiver()], tp) and - rawAlgorithmName = - t.(StructType).asItemNode().(Addressable).getCanonicalPath().splitAt("::") + rawAlgorithmName = t.(StructType).getStruct().(Addressable).getCanonicalPath().splitAt("::") ) and algorithmName = simplifyAlgorithmName(rawAlgorithmName) and // only match a known cryptographic algorithm diff --git a/rust/ql/lib/codeql/rust/internal/Type.qll b/rust/ql/lib/codeql/rust/internal/Type.qll index 29e6ed283bc6..9dc15e31d996 100644 --- a/rust/ql/lib/codeql/rust/internal/Type.qll +++ b/rust/ql/lib/codeql/rust/internal/Type.qll @@ -78,14 +78,6 @@ private predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypePar * types, such as traits and implementation blocks. */ abstract class Type extends TType { - /** Gets the struct field `name` belonging to this type, if any. */ - pragma[nomagic] - abstract StructField getStructField(string name); - - /** Gets the `i`th tuple field belonging to this type, if any. */ - pragma[nomagic] - abstract TupleField getTupleField(int i); - /** * Gets the `i`th positional type parameter of this type, if any. * @@ -117,10 +109,6 @@ class TupleType extends Type, TTuple { TupleType() { this = TTuple(arity) } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TTupleTypeParameter(arity, i) } @@ -140,21 +128,14 @@ class UnitType extends TupleType { override string toString() { result = "()" } } -abstract private class StructOrEnumType extends Type { - abstract ItemNode asItemNode(); -} - /** A struct type. */ -class StructType extends StructOrEnumType, TStruct { +class StructType extends Type, TStruct { private Struct struct; StructType() { this = TStruct(struct) } - override ItemNode asItemNode() { result = struct } - - override StructField getStructField(string name) { result = struct.getStructField(name) } - - override TupleField getTupleField(int i) { result = struct.getTupleField(i) } + /** Gets the struct that this struct type represents. */ + Struct getStruct() { result = struct } override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i)) @@ -170,17 +151,11 @@ class StructType extends StructOrEnumType, TStruct { } /** An enum type. */ -class EnumType extends StructOrEnumType, TEnum { +class EnumType extends Type, TEnum { private Enum enum; EnumType() { this = TEnum(enum) } - override ItemNode asItemNode() { result = enum } - - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i)) } @@ -203,10 +178,6 @@ class TraitType extends Type, TTrait { /** Gets the underlying trait. */ Trait getTrait() { result = trait } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i)) } @@ -229,16 +200,13 @@ class TraitType extends Type, TTrait { } /** A union type. */ -class UnionType extends StructOrEnumType, TUnion { +class UnionType extends Type, TUnion { private Union union; UnionType() { this = TUnion(union) } - override ItemNode asItemNode() { result = union } - - override StructField getStructField(string name) { result = union.getStructField(name) } - - override TupleField getTupleField(int i) { none() } + /** Gets the union that this union type represents. */ + Union getUnion() { result = union } override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(union.getGenericParamList().getTypeParam(i)) @@ -262,10 +230,6 @@ class UnionType extends StructOrEnumType, TUnion { class ArrayType extends Type, TArrayType { ArrayType() { this = TArrayType() } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TArrayTypeParameter() and i = 0 @@ -285,10 +249,6 @@ class ArrayType extends Type, TArrayType { class RefType extends Type, TRefType { RefType() { this = TRefType() } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TRefTypeParameter() and i = 0 @@ -318,10 +278,6 @@ class ImplTraitType extends Type, TImplTraitType { /** Gets the function that this `impl Trait` belongs to. */ abstract Function getFunction(); - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { exists(TypeParam tp | implTraitTypeParam(impl, i, tp) and @@ -339,10 +295,6 @@ class DynTraitType extends Type, TDynTraitType { DynTraitType() { this = TDynTraitType(trait) } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override DynTraitTypeParameter getPositionalTypeParameter(int i) { result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i)) } @@ -389,10 +341,6 @@ class ImplTraitReturnType extends ImplTraitType { class SliceType extends Type, TSliceType { SliceType() { this = TSliceType() } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { result = TSliceTypeParameter() and i = 0 @@ -404,10 +352,6 @@ class SliceType extends Type, TSliceType { } class NeverType extends Type, TNeverType { - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { none() } override string toString() { result = "!" } @@ -416,10 +360,6 @@ class NeverType extends Type, TNeverType { } class PtrType extends Type, TPtrType { - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { i = 0 and result = TPtrTypeParameter() @@ -432,10 +372,6 @@ class PtrType extends Type, TPtrType { /** A type parameter. */ abstract class TypeParameter extends Type { - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { none() } } @@ -634,10 +570,6 @@ class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter { override Function getFunction() { result = function } - override StructField getStructField(string name) { none() } - - override TupleField getTupleField(int i) { none() } - override TypeParameter getPositionalTypeParameter(int i) { none() } } diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index 1f987d6572ea..c9dbf0bac13b 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -1173,8 +1173,8 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) { path = TypePath::cons(TRefTypeParameter(), path0) else ( not ( - argType.(StructType).asItemNode() instanceof StringStruct and - result.(StructType).asItemNode() instanceof Builtins::Str + argType.(StructType).getStruct() instanceof StringStruct and + result.(StructType).getStruct() instanceof Builtins::Str ) and ( not path0.isCons(TRefTypeParameter(), _) and @@ -1889,8 +1889,8 @@ final class MethodCall extends Call { // // See also https://doc.rust-lang.org/reference/expressions/method-call-expr.html#r-expr.method.autoref-deref path.isEmpty() and - t0.(StructType).asItemNode() instanceof StringStruct and - result.(StructType).asItemNode() instanceof Builtins::Str + t0.(StructType).getStruct() instanceof StringStruct and + result.(StructType).getStruct() instanceof Builtins::Str ) else result = this.getReceiverTypeAt(path) } @@ -2518,7 +2518,10 @@ private module Cached { */ cached StructField resolveStructFieldExpr(FieldExpr fe) { - exists(string name | result = getFieldExprLookupType(fe, name).getStructField(name)) + exists(string name, Type ty | ty = getFieldExprLookupType(fe, name) | + result = ty.(StructType).getStruct().getStructField(name) or + result = ty.(UnionType).getUnion().getStructField(name) + ) } /** @@ -2526,7 +2529,9 @@ private module Cached { */ cached TupleField resolveTupleFieldExpr(FieldExpr fe) { - exists(int i | result = getTupleFieldExprLookupType(fe, i).getTupleField(i)) + exists(int i | + result = getTupleFieldExprLookupType(fe, i).(StructType).getStruct().getTupleField(i) + ) } /**