Skip to content

Commit e379795

Browse files
authored
Merge pull request #19754 from geoffw0/typeinfer
Rust: Type inference for `for` loops and array expressions
2 parents c380c5f + 96dcdf9 commit e379795

File tree

5 files changed

+498
-40
lines changed

5 files changed

+498
-40
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Added type inference for `for` loops and array expressions.

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,16 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
285285
prefix2.isEmpty()
286286
)
287287
)
288+
or
289+
// an array list expression (`[1, 2, 3]`) has the type of the first (any) element
290+
n1.(ArrayListExpr).getExpr(_) = n2 and
291+
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
292+
prefix2.isEmpty()
293+
or
294+
// an array repeat expression (`[1; 3]`) has the type of the repeat operand
295+
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
296+
prefix1 = TypePath::singleton(TArrayTypeParameter()) and
297+
prefix2.isEmpty()
288298
}
289299

290300
pragma[nomagic]
@@ -1037,6 +1047,12 @@ private class Vec extends Struct {
10371047
}
10381048
}
10391049

1050+
/**
1051+
* Gets the root type of the array expression `ae`.
1052+
*/
1053+
pragma[nomagic]
1054+
private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result = TArrayType() }
1055+
10401056
/**
10411057
* According to [the Rust reference][1]: _"array and slice-typed expressions
10421058
* can be indexed with a `usize` index ... For other types an index expression
@@ -1073,6 +1089,26 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
10731089
)
10741090
}
10751091

1092+
pragma[nomagic]
1093+
private Type inferForLoopExprType(AstNode n, TypePath path) {
1094+
// type of iterable -> type of pattern (loop variable)
1095+
exists(ForExpr fe, Type iterableType, TypePath iterablePath |
1096+
n = fe.getPat() and
1097+
iterableType = inferType(fe.getIterable(), iterablePath) and
1098+
result = iterableType and
1099+
(
1100+
iterablePath.isCons(any(Vec v).getElementTypeParameter(), path)
1101+
or
1102+
iterablePath.isCons(any(ArrayTypeParameter tp), path)
1103+
or
1104+
iterablePath
1105+
.stripPrefix(TypePath::cons(TRefTypeParameter(),
1106+
TypePath::singleton(any(SliceTypeParameter tp)))) = path
1107+
// TODO: iterables (general case for containers, ranges etc)
1108+
)
1109+
)
1110+
}
1111+
10761112
final class MethodCall extends Call {
10771113
MethodCall() {
10781114
exists(this.getReceiver()) and
@@ -1518,7 +1554,12 @@ private module Cached {
15181554
or
15191555
result = inferAwaitExprType(n, path)
15201556
or
1557+
result = inferArrayExprType(n) and
1558+
path.isEmpty()
1559+
or
15211560
result = inferIndexExprType(n, path)
1561+
or
1562+
result = inferForLoopExprType(n, path)
15221563
}
15231564
}
15241565

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,9 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3+
| main.rs:1963:13:1963:31 | ...::from(...) |
4+
| main.rs:1964:13:1964:31 | ...::from(...) |
5+
| main.rs:1965:13:1965:31 | ...::from(...) |
6+
| main.rs:1970:13:1970:31 | ...::from(...) |
7+
| main.rs:1971:13:1971:31 | ...::from(...) |
8+
| main.rs:1972:13:1972:31 | ...::from(...) |
9+
| main.rs:2006:21:2006:43 | ...::from(...) |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,11 +1910,7 @@ mod method_determined_by_argument_type {
19101910
impl MyAdd<bool> for i64 {
19111911
// MyAdd<bool>::my_add
19121912
fn my_add(&self, value: bool) -> Self {
1913-
if value {
1914-
1
1915-
} else {
1916-
0
1917-
}
1913+
if value { 1 } else { 0 }
19181914
}
19191915
}
19201916

@@ -1926,6 +1922,122 @@ mod method_determined_by_argument_type {
19261922
}
19271923
}
19281924

1925+
mod loops {
1926+
struct MyCallable {}
1927+
1928+
impl MyCallable {
1929+
fn new() -> Self {
1930+
MyCallable {}
1931+
}
1932+
1933+
fn call(&self) -> i64 {
1934+
1
1935+
}
1936+
}
1937+
1938+
pub fn f() {
1939+
// for loops with arrays
1940+
1941+
for i in [1, 2, 3] {} // $ type=i:i32
1942+
for i in [1, 2, 3].map(|x| x + 1) {} // $ method=map MISSING: type=i:i32
1943+
for i in [1, 2, 3].into_iter() {} // $ method=into_iter MISSING: type=i:i32
1944+
1945+
let vals1 = [1u8, 2, 3]; // $ type=vals1:[T;...].u8
1946+
for u in vals1 {} // $ type=u:u8
1947+
1948+
let vals2 = [1u16; 3]; // $ type=vals2:[T;...].u16
1949+
for u in vals2 {} // $ type=u:u16
1950+
1951+
let vals3: [u32; 3] = [1, 2, 3]; // $ type=vals3:[T;...].u32
1952+
for u in vals3 {} // $ type=u:u32
1953+
1954+
let vals4: [u64; 3] = [1; 3]; // $ type=vals4:[T;...].u64
1955+
for u in vals4 {} // $ type=u:u64
1956+
1957+
let mut strings1 = ["foo", "bar", "baz"]; // $ type=strings1:[T;...].str
1958+
for s in &strings1 {} // $ MISSING: type=s:&T.str
1959+
for s in &mut strings1 {} // $ MISSING: type=s:&T.str
1960+
for s in strings1 {} // $ type=s:str
1961+
1962+
let strings2 = [ // $ type=strings2:[T;...].String
1963+
String::from("foo"),
1964+
String::from("bar"),
1965+
String::from("baz"),
1966+
];
1967+
for s in strings2 {} // $ type=s:String
1968+
1969+
let strings3 = &[ // $ type=strings3:&T.[T;...].String
1970+
String::from("foo"),
1971+
String::from("bar"),
1972+
String::from("baz"),
1973+
];
1974+
for s in strings3 {} // $ MISSING: type=s:String
1975+
1976+
let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[T;...].MyCallable; 3
1977+
for c in callables // $ type=c:MyCallable
1978+
{
1979+
let result = c.call(); // $ type=result:i64 method=call
1980+
}
1981+
1982+
// for loops with ranges
1983+
1984+
for i in 0..10 {} // $ MISSING: type=i:i32
1985+
for u in [0u8..10] {} // $ MISSING: type=u:u8
1986+
let range = 0..10; // $ MISSING: type=range:Range type=range:Idx.i32
1987+
for i in range {} // $ MISSING: type=i:i32
1988+
1989+
let range1 = std::ops::Range { // $ type=range1:Range type=range1:Idx.u16
1990+
start: 0u16,
1991+
end: 10u16,
1992+
};
1993+
for u in range1 {} // $ MISSING: type=u:u16
1994+
1995+
// for loops with containers
1996+
1997+
let vals3 = vec![1, 2, 3]; // $ MISSING: type=vals3:Vec type=vals3:T.i32
1998+
for i in vals3 {} // $ MISSING: type=i:i32
1999+
2000+
let vals4a: Vec<u16> = [1u16, 2, 3].to_vec(); // $ type=vals4a:Vec type=vals4a:T.u16
2001+
for u in vals4a {} // $ type=u:u16
2002+
2003+
let vals4b = [1u16, 2, 3].to_vec(); // $ MISSING: type=vals4b:Vec type=vals4b:T.u16
2004+
for u in vals4b {} // $ MISSING: type=u:u16
2005+
2006+
let vals5 = Vec::from([1u32, 2, 3]); // $ type=vals5:Vec MISSING: type=vals5:T.u32
2007+
for u in vals5 {} // $ MISSING: type=u:u32
2008+
2009+
let vals6: Vec<&u64> = [1u64, 2, 3].iter().collect(); // $ type=vals6:Vec type=vals6:T.&T.u64
2010+
for u in vals6 {} // $ type=u:&T.u64
2011+
2012+
let mut vals7 = Vec::new(); // $ type=vals7:Vec MISSING: type=vals7:T.u8
2013+
vals7.push(1u8); // $ method=push
2014+
for u in vals7 {} // $ MISSING: type=u:u8
2015+
2016+
let matrix1 = vec![vec![1, 2], vec![3, 4]]; // $ MISSING: type=matrix1:Vec type=matrix1:T.Vec type=matrix1:T.T.i32
2017+
for row in matrix1 {
2018+
// $ MISSING: type=row:Vec type=row:T.i32
2019+
for cell in row { // $ MISSING: type=cell:i32
2020+
}
2021+
}
2022+
2023+
let mut map1 = std::collections::HashMap::new(); // $ MISSING: type=map1:Hashmap type=map1:K.i32 type=map1:V.Box type1=map1:V.T.&T.str
2024+
map1.insert(1, Box::new("one")); // $ method=insert
2025+
map1.insert(2, Box::new("two")); // $ method=insert
2026+
for key in map1.keys() {} // $ method=keys MISSING: type=key:i32
2027+
for value in map1.values() {} // $ method=values MISSING: type=value:Box type=value:T.&T.str
2028+
for (key, value) in map1.iter() {} // $ method=iter MISSING: type=key:i32 type=value:Box type=value:T.&T.str
2029+
for (key, value) in &map1 {} // $ MISSING: type=key:i32 type=value:Box type=value:T.&T.str
2030+
2031+
// while loops
2032+
2033+
let mut a: i64 = 0; // $ type=a:i64
2034+
while a < 10 // $ method=lt type=a:i64
2035+
{
2036+
a += 1; // $ type=a:i64 method=add_assign
2037+
}
2038+
}
2039+
}
2040+
19292041
mod dereference;
19302042

19312043
fn main() {
@@ -1950,6 +2062,7 @@ fn main() {
19502062
async_::f();
19512063
impl_trait::f();
19522064
indexers::f();
2065+
loops::f();
19532066
macros::f();
19542067
method_determined_by_argument_type::f();
19552068
dereference::test();

0 commit comments

Comments
 (0)