diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index 7668ab88651f..b0179567ae8b 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -257,7 +257,7 @@ private Type inferAnnotatedType(AstNode n, TypePath path) { } /** Module for inferring certain type information. */ -private module CertainTypeInference { +module CertainTypeInference { pragma[nomagic] private predicate callResolvesTo(CallExpr ce, Path p, Function f) { p = CallExprImpl::getFunctionPath(ce) and @@ -286,7 +286,7 @@ private module CertainTypeInference { } pragma[nomagic] - Type inferCertainCallExprType(CallExpr ce, TypePath path) { + private Type inferCertainCallExprType(CallExpr ce, TypePath path) { exists(Type ty, TypePath prefix, Path p | ty = getCertainCallExprType(ce, p, prefix) | exists(TypePath suffix, TypeParam tp | tp = ty.(TypeParamTypeParameter).getTypeParam() and diff --git a/rust/ql/test/library-tests/type-inference/closure.rs b/rust/ql/test/library-tests/type-inference/closure.rs index 1b11335947c1..43b26819325e 100644 --- a/rust/ql/test/library-tests/type-inference/closure.rs +++ b/rust/ql/test/library-tests/type-inference/closure.rs @@ -5,7 +5,7 @@ mod simple_closures { // A simple closure without type annotations or invocations. let my_closure = |a, b| a && b; - let x: i64 = 1i64; // $ type=x:i64 + let x: i64 = 1i64; // $ certainType=x:i64 let add_one = |n| n + 1i64; // $ target=add let _y = add_one(x); // $ type=_y:i64 @@ -27,7 +27,7 @@ mod simple_closures { // The return type of `id2` is inferred from the type of the call expression. let id2 = |b| b; let arg = Default::default(); // $ target=default type=arg:bool - let _b2: bool = id2(arg); // $ type=_b2:bool + let _b2: bool = id2(arg); // $ certainType=_b2:bool } } @@ -60,7 +60,7 @@ mod fn_once_trait { let _r = apply(f, true); // $ target=apply type=_r:i64 let f = |x| x + 1; // $ MISSING: type=x:i64 target=add - let _r2 = apply_two(f); // $ target=apply_two type=_r2:i64 + let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64 } } diff --git a/rust/ql/test/library-tests/type-inference/main.rs b/rust/ql/test/library-tests/type-inference/main.rs index 6685b80ae755..148b60077f87 100644 --- a/rust/ql/test/library-tests/type-inference/main.rs +++ b/rust/ql/test/library-tests/type-inference/main.rs @@ -1140,20 +1140,20 @@ mod type_aliases { println!("{:?}", p1); // Type can be only inferred from the type alias - let p2: MyPair = PairOption::PairNone(); // $ type=p2:Fst.S1 type=p2:Snd.S2 + let p2: MyPair = PairOption::PairNone(); // $ certainType=p2:Fst.S1 certainType=p2:Snd.S2 println!("{:?}", p2); // First type from alias, second from constructor - let p3: AnotherPair<_> = PairOption::PairSnd(S3); // $ type=p3:Fst.S2 + let p3: AnotherPair<_> = PairOption::PairSnd(S3); // $ certainType=p3:Fst.S2 println!("{:?}", p3); // First type from alias definition, second from argument to alias - let p3: AnotherPair = PairOption::PairNone(); // $ type=p3:Fst.S2 type=p3:Snd.S3 + let p3: AnotherPair = PairOption::PairNone(); // $ certainType=p3:Fst.S2 certainType=p3:Snd.S3 println!("{:?}", p3); g(PairOption::PairSnd(PairOption::PairSnd(S3))); // $ target=g - let x: S7; // $ type=x:Result $ type=x:E.S1 $ type=x:T.S4 $ type=x:T.T41.S2 $ type=x:T.T42.S5 $ type=x:T.T42.T5.S2 + let x: S7; // $ certainType=x:Result $ certainType=x:E.S1 $ certainType=x:T.S4 $ certainType=x:T.T41.S2 $ certainType=x:T.T42.S5 $ certainType=x:T.T42.T5.S2 let y = GenS(true).get_input(); // $ type=y:Result type=y:T.bool type=y:E.bool target=get_input } @@ -1199,7 +1199,7 @@ mod option_methods { struct S; pub fn f() { - let x1 = MyOption::::new(); // $ type=x1:T.S target=new + let x1 = MyOption::::new(); // $ certainType=x1:T.S target=new println!("{:?}", x1); let mut x2 = MyOption::new(); // $ target=new @@ -1327,7 +1327,7 @@ mod method_call_type_conversion { let t = x7.m1(); // $ target=m1 type=t:& type=t:&T.S2 println!("{:?}", x7); - let x9: String = "Hello".to_string(); // $ type=x9:String + let x9: String = "Hello".to_string(); // $ certainType=x9:String // Implicit `String` -> `str` conversion happens via the `Deref` trait: // https://doc.rust-lang.org/std/string/struct.String.html#deref. @@ -1502,23 +1502,23 @@ mod try_expressions { mod builtins { pub fn f() { - let x: i32 = 1; // $ type=x:i32 + let x: i32 = 1; // $ certainType=x:i32 let y = 2; // $ type=y:i32 let z = x + y; // $ type=z:i32 target=add let z = x.abs(); // $ target=abs $ type=z:i32 - let c = 'c'; // $ type=c:char - let hello = "Hello"; // $ type=hello:&T.str - let f = 123.0f64; // $ type=f:f64 - let t = true; // $ type=t:bool - let f = false; // $ type=f:bool + let c = 'c'; // $ certainType=c:char + let hello = "Hello"; // $ certainType=hello:&T.str + let f = 123.0f64; // $ certainType=f:f64 + let t = true; // $ certainType=t:bool + let f = false; // $ certainType=f:bool } } // Tests for non-overloaded operators. mod operators { pub fn f() { - let x = true && false; // $ type=x:bool - let y = true || false; // $ type=y:bool + let x = true && false; // $ certainType=x:bool + let y = true || false; // $ certainType=y:bool let mut a; let cond = 34 == 33; // $ target=eq @@ -2292,10 +2292,10 @@ mod loops { let vals2 = [1u16; 3]; // $ type=vals2:[T;...].u16 for u in vals2 {} // $ type=u:u16 - let vals3: [u32; 3] = [1, 2, 3]; // $ type=vals3:[T;...].u32 + let vals3: [u32; 3] = [1, 2, 3]; // $ certainType=vals3:[T;...].u32 for u in vals3 {} // $ type=u:u32 - let vals4: [u64; 3] = [1; 3]; // $ type=vals4:[T;...].u64 + let vals4: [u64; 3] = [1; 3]; // $ certainType=vals4:[T;...].u64 for u in vals4 {} // $ type=u:u64 let mut strings1 = ["foo", "bar", "baz"]; // $ type=strings1:[T;...].&T.str @@ -2330,9 +2330,9 @@ mod loops { for i in 0..10 {} // $ type=i:i32 for u in [0u8..10] {} // $ type=u:Range type=u:Idx.u8 - let range = 0..10; // $ type=range:Range type=range:Idx.i32 + let range = 0..10; // $ certainType=range:Range type=range:Idx.i32 for i in range {} // $ type=i:i32 - let range_full = ..; // $ type=range_full:RangeFull + let range_full = ..; // $ certainType=range_full:RangeFull for i in &[1i64, 2i64, 3i64][range_full] {} // $ target=index MISSING: type=i:&T.i64 let range1 = // $ type=range1:Range type=range1:Idx.u16 @@ -2347,19 +2347,19 @@ mod loops { let vals3 = vec![1, 2, 3]; // $ MISSING: type=vals3:Vec type=vals3:T.i32 for i in vals3 {} // $ MISSING: type=i:i32 - let vals4a: Vec = [1u16, 2, 3].to_vec(); // $ type=vals4a:Vec type=vals4a:T.u16 + let vals4a: Vec = [1u16, 2, 3].to_vec(); // $ certainType=vals4a:Vec certainType=vals4a:T.u16 for u in vals4a {} // $ type=u:u16 let vals4b = [1u16, 2, 3].to_vec(); // $ MISSING: type=vals4b:Vec type=vals4b:T.u16 for u in vals4b {} // $ MISSING: type=u:u16 - let vals5 = Vec::from([1u32, 2, 3]); // $ type=vals5:Vec target=from type=vals5:T.u32 + let vals5 = Vec::from([1u32, 2, 3]); // $ certainType=vals5:Vec target=from type=vals5:T.u32 for u in vals5 {} // $ type=u:u32 - let vals6: Vec<&u64> = [1u64, 2, 3].iter().collect(); // $ type=vals6:Vec type=vals6:T.&T.u64 + let vals6: Vec<&u64> = [1u64, 2, 3].iter().collect(); // $ certainType=vals6:Vec certainType=vals6:T.&T.u64 for u in vals6 {} // $ type=u:&T.u64 - let mut vals7 = Vec::new(); // $ target=new type=vals7:Vec type=vals7:T.u8 + let mut vals7 = Vec::new(); // $ target=new certainType=vals7:Vec type=vals7:T.u8 vals7.push(1u8); // $ target=push for u in vals7 {} // $ type=u:u8 @@ -2380,11 +2380,11 @@ mod loops { // while loops - let mut a: i64 = 0; // $ type=a:i64 + let mut a: i64 = 0; // $ certainType=a:i64 #[rustfmt::skip] - let _ = while a < 10 // $ target=lt type=a:i64 + let _ = while a < 10 // $ target=lt certainType=a:i64 { - a += 1; // $ type=a:i64 MISSING: target=add_assign + a += 1; // $ certainType=a:i64 MISSING: target=add_assign }; } } @@ -2422,11 +2422,11 @@ mod explicit_type_args { } pub fn f() { - let x1: Option> = S1::assoc_fun(); // $ type=x1:T.T.S2 target=assoc_fun - let x2 = S1::::assoc_fun(); // $ type=x2:T.T.S2 target=assoc_fun - let x3 = S3::assoc_fun(); // $ type=x3:T.T.S2 target=assoc_fun - let x4 = S1::::method(S1::default()); // $ target=method target=default type=x4:T.S2 - let x5 = S3::method(S1::default()); // $ target=method target=default type=x5:T.S2 + let x1: Option> = S1::assoc_fun(); // $ certainType=x1:T.T.S2 target=assoc_fun + let x2 = S1::::assoc_fun(); // $ certainType=x2:T.T.S2 target=assoc_fun + let x3 = S3::assoc_fun(); // $ certainType=x3:T.T.S2 target=assoc_fun + let x4 = S1::::method(S1::default()); // $ target=method target=default certainType=x4:T.S2 + let x5 = S3::method(S1::default()); // $ target=method target=default certainType=x5:T.S2 let x6 = S4::(Default::default()); // $ type=x6:T4.S2 target=default let x7 = S4(S2); // $ type=x7:T4.S2 let x8 = S4(0); // $ type=x8:T4.i32 @@ -2441,7 +2441,7 @@ mod explicit_type_args { { field: S2::default(), // $ target=default }; - let x14 = foo::(Default::default()); // $ type=x14:i32 target=default target=foo + let x14 = foo::(Default::default()); // $ certainType=x14:i32 target=default target=foo } } @@ -2457,8 +2457,8 @@ mod tuples { } pub fn f() { - let a = S1::get_pair(); // $ target=get_pair type=a:(T_2) - let mut b = S1::get_pair(); // $ target=get_pair type=b:(T_2) + let a = S1::get_pair(); // $ target=get_pair certainType=a:(T_2) + let mut b = S1::get_pair(); // $ target=get_pair certainType=b:(T_2) let (c, d) = S1::get_pair(); // $ target=get_pair type=c:S1 type=d:S1 let (mut e, f) = S1::get_pair(); // $ target=get_pair type=e:S1 type=f:S1 let (mut g, mut h) = S1::get_pair(); // $ target=get_pair type=g:S1 type=h:S1 @@ -2593,11 +2593,11 @@ pub mod path_buf { } pub fn f() { - let path1 = Path::new(); // $ target=new type=path1:Path + let path1 = Path::new(); // $ target=new certainType=path1:Path let path2 = path1.canonicalize(); // $ target=canonicalize let path3 = path2.unwrap(); // $ target=unwrap type=path3:PathBuf - let pathbuf1 = PathBuf::new(); // $ target=new type=pathbuf1:PathBuf + let pathbuf1 = PathBuf::new(); // $ target=new certainType=pathbuf1:PathBuf let pathbuf2 = pathbuf1.canonicalize(); // $ MISSING: target=canonicalize let pathbuf3 = pathbuf2.unwrap(); // $ MISSING: target=unwrap type=pathbuf3:PathBuf } diff --git a/rust/ql/test/library-tests/type-inference/pattern_matching.rs b/rust/ql/test/library-tests/type-inference/pattern_matching.rs index 30ddd61444e7..9e40560f18c8 100755 --- a/rust/ql/test/library-tests/type-inference/pattern_matching.rs +++ b/rust/ql/test/library-tests/type-inference/pattern_matching.rs @@ -171,15 +171,15 @@ pub fn literal_patterns() { match value { // LiteralPat - Literal patterns (including negative literals) 42 => { - let literal_match = value; // $ type=literal_match:i32 + let literal_match = value; // $ certainType=literal_match:i32 println!("Literal pattern: {}", literal_match); } -1 => { - let negative_literal = value; // $ type=negative_literal:i32 + let negative_literal = value; // $ certainType=negative_literal:i32 println!("Negative literal: {}", negative_literal); } 0 => { - let zero_literal = value; // $ type=zero_literal:i32 + let zero_literal = value; // $ certainType=zero_literal:i32 println!("Zero literal: {}", zero_literal); } _ => {} @@ -188,7 +188,7 @@ pub fn literal_patterns() { let float_val = 3.14f64; match float_val { 3.14 => { - let pi_match = float_val; // $ type=pi_match:f64 + let pi_match = float_val; // $ certainType=pi_match:f64 println!("Pi matched: {}", pi_match); } _ => {} @@ -197,7 +197,7 @@ pub fn literal_patterns() { let string_val = "hello"; match string_val { "hello" => { - let hello_match = string_val; // $ type=hello_match:&T.str + let hello_match = string_val; // $ certainType=hello_match:&T.str println!("String literal: {}", hello_match); } _ => {} @@ -206,11 +206,11 @@ pub fn literal_patterns() { let bool_val = true; match bool_val { true => { - let true_match = bool_val; // $ type=true_match:bool + let true_match = bool_val; // $ certainType=true_match:bool println!("True literal: {}", true_match); } false => { - let false_match = bool_val; // $ type=false_match:bool + let false_match = bool_val; // $ certainType=false_match:bool println!("False literal: {}", false_match); } } @@ -283,7 +283,7 @@ pub fn wildcard_patterns() { 42 => println!("Specific match"), // WildcardPat - Wildcard pattern _ => { - let wildcard_context = value; // $ type=wildcard_context:i32 + let wildcard_context = value; // $ certainType=wildcard_context:i32 println!("Wildcard pattern for: {}", wildcard_context); } } @@ -295,15 +295,15 @@ pub fn range_patterns() { match value { // RangePat - Range patterns 1..=10 => { - let range_inclusive = value; // $ type=range_inclusive:i32 + let range_inclusive = value; // $ certainType=range_inclusive:i32 println!("Range inclusive: {}", range_inclusive); } 11.. => { - let range_from = value; // $ type=range_from:i32 + let range_from = value; // $ certainType=range_from:i32 println!("Range from 11: {}", range_from); } ..=0 => { - let range_to_inclusive = value; // $ type=range_to_inclusive:i32 + let range_to_inclusive = value; // $ certainType=range_to_inclusive:i32 println!("Range to 0 inclusive: {}", range_to_inclusive); } _ => {} @@ -312,11 +312,11 @@ pub fn range_patterns() { let char_val = 'c'; match char_val { 'a'..='z' => { - let lowercase_char = char_val; // $ type=lowercase_char:char + let lowercase_char = char_val; // $ certainType=lowercase_char:char println!("Lowercase char: {}", lowercase_char); } 'A'..='Z' => { - let uppercase_char = char_val; // $ type=uppercase_char:char + let uppercase_char = char_val; // $ certainType=uppercase_char:char println!("Uppercase char: {}", uppercase_char); } _ => {} @@ -330,7 +330,7 @@ pub fn reference_patterns() { // RefPat - Reference patterns match &value { &42 => { - let deref_match = value; // $ type=deref_match:i32 + let deref_match = value; // $ certainType=deref_match:i32 println!("Dereferenced match: {}", deref_match); } &x => { @@ -446,7 +446,7 @@ pub fn tuple_patterns() { // TuplePat - Tuple patterns match tuple { (1, 2, 3.0) => { - let exact_tuple = tuple; // $ type=exact_tuple:(T_3) + let exact_tuple = tuple; // $ certainType=exact_tuple:(T_3) println!("Exact tuple: {:?}", exact_tuple); } (a, b, c) => { @@ -469,7 +469,7 @@ pub fn tuple_patterns() { let unit = (); match unit { () => { - let unit_value = unit; // $ type=unit_value:() + let unit_value = unit; // $ certainType=unit_value:() println!("Unit value: {:?}", unit_value); } } @@ -525,7 +525,7 @@ pub fn slice_patterns() { // SlicePat - Slice patterns match slice { [] => { - let empty_slice = slice; // $ type=empty_slice:&T.[T].i32 + let empty_slice = slice; // $ certainType=empty_slice:&T.[T].i32 println!("Empty slice: {:?}", empty_slice); } [x] => { @@ -569,7 +569,7 @@ pub fn path_patterns() { match value { CONSTANT => { - let const_match = value; // $ type=const_match:i32 + let const_match = value; // $ certainType=const_match:i32 println!("Matches constant: {}", const_match); } _ => {} @@ -606,11 +606,11 @@ pub fn or_patterns() { // OrPat - Or patterns match value { 1 | 2 | 3 => { - let small_num = value; // $ type=small_num:i32 + let small_num = value; // $ certainType=small_num:i32 println!("Small number: {}", small_num); } 10 | 20 => { - let round_num = value; // $ type=round_num:i32 + let round_num = value; // $ certainType=round_num:i32 println!("Round number: {}", round_num); } _ => {} @@ -630,7 +630,7 @@ pub fn or_patterns() { // Or pattern with ranges match value { 1..=10 | 90..=100 => { - let range_or_value = value; // $ type=range_or_value:i32 + let range_or_value = value; // $ certainType=range_or_value:i32 println!("In range: {}", range_or_value); } _ => {} @@ -750,11 +750,11 @@ pub fn patterns_in_let_statements() { // Let with reference pattern let value = 42i32; let ref ref_val = value; - let let_ref = ref_val; // $ type=let_ref:&T.i32 + let let_ref = ref_val; // $ certainType=let_ref:&T.i32 // Let with mutable pattern let mut mut_val = 10i32; - let let_mut = mut_val; // $ type=let_mut:i32 + let let_mut = mut_val; // $ certainType=let_mut:i32 } pub fn patterns_in_function_parameters() { @@ -779,13 +779,13 @@ pub fn patterns_in_function_parameters() { // Call the functions to use them let point = Point { x: 5, y: 10 }; - let extracted = extract_point(point); // $ target=extract_point type=extracted:0(2).i32 type=extracted:1(2).i32 + let extracted = extract_point(point); // $ target=extract_point certainType=extracted:0(2).i32 certainType=extracted:1(2).i32 let color = Color(200, 100, 50); - let red = extract_color(color); // $ target=extract_color type=red:u8 + let red = extract_color(color); // $ target=extract_color certainType=red:u8 let tuple = (42i32, 3.14f64, true); - let tuple_extracted = extract_tuple(tuple); // $ target=extract_tuple type=tuple_extracted:0(2).i32 type=tuple_extracted:1(2).bool + let tuple_extracted = extract_tuple(tuple); // $ target=extract_tuple certainType=tuple_extracted:0(2).i32 certainType=tuple_extracted:1(2).bool } #[rustfmt::skip] diff --git a/rust/ql/test/library-tests/type-inference/type-inference.ql b/rust/ql/test/library-tests/type-inference/type-inference.ql index 2122c7898a7d..059cc7848a0a 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.ql +++ b/rust/ql/test/library-tests/type-inference/type-inference.ql @@ -48,24 +48,27 @@ module ResolveTest implements TestSig { } module TypeTest implements TestSig { - string getARelevantTag() { result = "type" } + string getARelevantTag() { result = ["type", "certainType"] } predicate tagIsOptional(string expectedTag) { expectedTag = "type" } predicate hasActualResult(Location location, string element, string tag, string value) { none() } predicate hasOptionalResult(Location location, string element, string tag, string value) { - tag = "type" and exists(AstNode n, TypePath path, Type t | t = TypeInference::inferType(n, path) and + ( + if t = TypeInference::CertainTypeInference::inferCertainType(n, path) + then tag = "certainType" + else tag = "type" + ) and location = n.getLocation() and - if path.isEmpty() - then value = element + ":" + t - else value = element + ":" + path.toString() + "." + t.toString() - | - element = n.toString() - or - element = n.(IdentPat).getName().getText() + ( + if path.isEmpty() + then value = element + ":" + t + else value = element + ":" + path.toString() + "." + t.toString() + ) and + element = [n.toString(), n.(IdentPat).getName().getText()] ) } }