Skip to content

Commit b813010

Browse files
authored
Merge pull request #19903 from hvitved/rust/type-inference-overlap2
Rust: Apply inherent method prioritization inside type inference loop
2 parents d6b051e + 0723391 commit b813010

File tree

9 files changed

+2725
-2559
lines changed

9 files changed

+2725
-2559
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 51 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
699699
}
700700

701701
Declaration getTarget() {
702-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
702+
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
703703
or
704704
result = CallExprImpl::getResolvedFunction(this)
705705
}
@@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
11781178
methodCandidate(type, name, arity, impl)
11791179
}
11801180

1181-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1182-
pragma[nomagic]
1183-
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1184-
rootType = mc.getTypeAt(TypePath::nil()) and
1185-
name = mc.getMethodName() and
1186-
arity = mc.getNumberOfArguments()
1187-
}
1181+
pragma[nomagic]
1182+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1183+
rootType = mc.getTypeAt(TypePath::nil()) and
1184+
name = mc.getMethodName() and
1185+
arity = mc.getNumberOfArguments()
1186+
}
11881187

1188+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
11891189
pragma[nomagic]
11901190
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
11911191
exists(Type rootType, string name, int arity |
@@ -1334,17 +1334,46 @@ private predicate methodResolutionDependsOnArgument(
13341334
)
13351335
}
13361336

1337+
/**
1338+
* Holds if the method call `mc` has no inherent target, i.e., it does not
1339+
* resolve to a method in an `impl` block for the type of the receiver.
1340+
*/
1341+
pragma[nomagic]
1342+
private predicate methodCallHasNoInherentTarget(MethodCall mc) {
1343+
exists(Type rootType, string name, int arity |
1344+
isMethodCall(mc, rootType, name, arity) and
1345+
forall(Impl impl |
1346+
methodCandidate(rootType, name, arity, impl) and
1347+
not impl.hasTrait()
1348+
|
1349+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, impl, _)
1350+
)
1351+
)
1352+
}
1353+
1354+
pragma[nomagic]
1355+
private predicate methodCallHasImplCandidate(MethodCall mc, Impl impl) {
1356+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1357+
if impl.hasTrait() and not exists(mc.getTrait())
1358+
then
1359+
// inherent methods take precedence over trait methods, so only allow
1360+
// trait methods when there are no matching inherent methods
1361+
methodCallHasNoInherentTarget(mc)
1362+
else any()
1363+
}
1364+
13371365
/** Gets a method from an `impl` block that matches the method call `mc`. */
1366+
pragma[nomagic]
13381367
private Function getMethodFromImpl(MethodCall mc) {
1339-
exists(Impl impl |
1340-
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1341-
result = getMethodSuccessor(impl, mc.getMethodName())
1368+
exists(Impl impl, string name |
1369+
methodCallHasImplCandidate(mc, impl) and
1370+
name = mc.getMethodName() and
1371+
result = getMethodSuccessor(impl, name)
13421372
|
1343-
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
1344-
result = getMethodSuccessor(impl, mc.getMethodName())
1373+
not methodResolutionDependsOnArgument(impl, _, _, _, _, _)
13451374
or
13461375
exists(int pos, TypePath path, Type type |
1347-
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1376+
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
13481377
inferType(mc.getPositionalArgument(pos), path) = type
13491378
)
13501379
)
@@ -1356,22 +1385,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
13561385
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
13571386
}
13581387

1359-
/**
1360-
* Gets a method that the method call `mc` resolves to based on type inference,
1361-
* if any.
1362-
*/
1363-
private Function inferMethodCallTarget(MethodCall mc) {
1364-
// The method comes from an `impl` block targeting the type of the receiver.
1365-
result = getMethodFromImpl(mc)
1366-
or
1367-
// The type of the receiver is a type parameter and the method comes from a
1368-
// trait bound on the type parameter.
1369-
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1370-
or
1371-
// The type of the receiver is an `impl Trait` type.
1372-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1373-
}
1374-
13751388
cached
13761389
private module Cached {
13771390
private import codeql.rust.internal.CachedStages
@@ -1400,47 +1413,18 @@ private module Cached {
14001413
)
14011414
}
14021415

1403-
private predicate isInherentImplFunction(Function f) {
1404-
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1405-
}
1406-
1407-
private predicate isTraitImplFunction(Function f) {
1408-
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1409-
}
1410-
1411-
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1412-
result = inferMethodCallTarget(mc) and
1413-
(if result.fromSource() then fromSource = true else fromSource = false) and
1414-
(
1415-
// prioritize inherent implementation methods first
1416-
isInherentImplFunction(result)
1417-
or
1418-
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1419-
(
1420-
// then trait implementation methods
1421-
isTraitImplFunction(result)
1422-
or
1423-
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1424-
(
1425-
// then trait methods with default implementations
1426-
result.hasBody()
1427-
or
1428-
// and finally trait methods without default implementations
1429-
not inferMethodCallTarget(mc).hasBody()
1430-
)
1431-
)
1432-
)
1433-
}
1434-
14351416
/** Gets a method that the method call `mc` resolves to, if any. */
14361417
cached
14371418
Function resolveMethodCallTarget(MethodCall mc) {
1438-
// Functions in source code also gets extracted as library code, due to
1439-
// this duplication we prioritize functions from source code.
1440-
result = resolveMethodCallTargetFrom(mc, true)
1419+
// The method comes from an `impl` block targeting the type of the receiver.
1420+
result = getMethodFromImpl(mc)
1421+
or
1422+
// The type of the receiver is a type parameter and the method comes from a
1423+
// trait bound on the type parameter.
1424+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14411425
or
1442-
not exists(resolveMethodCallTargetFrom(mc, true)) and
1443-
result = resolveMethodCallTargetFrom(mc, false)
1426+
// The type of the receiver is an `impl Trait` type.
1427+
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14441428
}
14451429

14461430
pragma[inline]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| main.rs:362:14:362:30 | ... .lt(...) |

rust/ql/test/library-tests/dataflow/sources/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ multipleCallTargets
1010
| test.rs:168:26:168:111 | ...::_print(...) |
1111
| test.rs:178:30:178:68 | ...::_print(...) |
1212
| test.rs:187:26:187:105 | ...::_print(...) |
13+
| test.rs:228:22:228:72 | ... .read_to_string(...) |
14+
| test.rs:482:22:482:50 | file.read_to_end(...) |
15+
| test.rs:488:22:488:53 | file.read_to_string(...) |
1316
| test.rs:609:18:609:38 | ...::_print(...) |
1417
| test.rs:614:18:614:45 | ...::_print(...) |
1518
| test.rs:618:25:618:49 | address.to_socket_addrs() |

rust/ql/test/library-tests/dataflow/sources/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ fn test_io_stdin() -> std::io::Result<()> {
214214
{
215215
let mut buffer = Vec::<u8>::new();
216216
let _bytes = std::io::stdin().read_to_end(&mut buffer)?; // $ Alert[rust/summary/taint-sources]
217-
sink(&buffer); // $ MISSING: hasTaintFlow
217+
sink(&buffer); // $ hasTaintFlow -- @hvitved: works in CI, but not for me locally
218218
}
219219

220220
{
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:1963:13:1963:31 | ...::from(...) |
4-
| main.rs:1964:13:1964:31 | ...::from(...) |
5-
| main.rs:1965:13:1965:31 | ...::from(...) |
6-
| main.rs:1970:13:1970:31 | ...::from(...) |
7-
| main.rs:1971:13:1971:31 | ...::from(...) |
8-
| main.rs:1972:13:1972:31 | ...::from(...) |
9-
| main.rs:2006:21:2006:43 | ...::from(...) |
3+
| main.rs:2032:13:2032:31 | ...::from(...) |
4+
| main.rs:2033:13:2033:31 | ...::from(...) |
5+
| main.rs:2034:13:2034:31 | ...::from(...) |
6+
| main.rs:2040:13:2040:31 | ...::from(...) |
7+
| main.rs:2041:13:2041:31 | ...::from(...) |
8+
| main.rs:2042:13:2042:31 | ...::from(...) |
9+
| main.rs:2078:21:2078:43 | ...::from(...) |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,12 @@ mod impl_overlap {
406406
impl OverlappingTrait for S1 {
407407
// <S1_as_OverlappingTrait>::common_method
408408
fn common_method(self) -> S1 {
409-
panic!("not called");
409+
S1
410410
}
411411

412412
// <S1_as_OverlappingTrait>::common_method_2
413413
fn common_method_2(self, s1: S1) -> S1 {
414-
panic!("not called");
414+
S1
415415
}
416416
}
417417

@@ -427,10 +427,78 @@ mod impl_overlap {
427427
}
428428
}
429429

430+
struct S2<T2>(T2);
431+
432+
impl S2<i32> {
433+
// S2<i32>::common_method
434+
fn common_method(self) -> S1 {
435+
S1
436+
}
437+
438+
// S2<i32>::common_method
439+
fn common_method_2(self) -> S1 {
440+
S1
441+
}
442+
}
443+
444+
impl OverlappingTrait for S2<i32> {
445+
// <S2<i32>_as_OverlappingTrait>::common_method
446+
fn common_method(self) -> S1 {
447+
S1
448+
}
449+
450+
// <S2<i32>_as_OverlappingTrait>::common_method_2
451+
fn common_method_2(self, s1: S1) -> S1 {
452+
S1
453+
}
454+
}
455+
456+
impl OverlappingTrait for S2<S1> {
457+
// <S2<S1>_as_OverlappingTrait>::common_method
458+
fn common_method(self) -> S1 {
459+
S1
460+
}
461+
462+
// <S2<S1>_as_OverlappingTrait>::common_method_2
463+
fn common_method_2(self, s1: S1) -> S1 {
464+
S1
465+
}
466+
}
467+
468+
#[derive(Debug)]
469+
struct S3<T3>(T3);
470+
471+
trait OverlappingTrait2<T> {
472+
fn m(&self, x: &T) -> &Self;
473+
}
474+
475+
impl<T> OverlappingTrait2<T> for S3<T> {
476+
// <S3<T>_as_OverlappingTrait2<T>>::m
477+
fn m(&self, x: &T) -> &Self {
478+
self
479+
}
480+
}
481+
482+
impl<T> S3<T> {
483+
// S3<T>::m
484+
fn m(&self, x: T) -> &Self {
485+
self
486+
}
487+
}
488+
430489
pub fn f() {
431490
let x = S1;
432491
println!("{:?}", x.common_method()); // $ method=S1::common_method
433492
println!("{:?}", x.common_method_2()); // $ method=S1::common_method_2
493+
494+
let y = S2(S1);
495+
println!("{:?}", y.common_method()); // $ method=<S2<S1>_as_OverlappingTrait>::common_method
496+
497+
let z = S2(0);
498+
println!("{:?}", z.common_method()); // $ method=S2<i32>::common_method
499+
500+
let w = S3(S1);
501+
println!("{:?}", w.m(x)); // $ method=S3<T>::m
434502
}
435503
}
436504

@@ -1959,22 +2027,25 @@ mod loops {
19592027
for s in &mut strings1 {} // $ MISSING: type=s:&T.str
19602028
for s in strings1 {} // $ type=s:str
19612029

1962-
let strings2 = [ // $ type=strings2:[T;...].String
2030+
let strings2 = // $ type=strings2:[T;...].String
2031+
[
19632032
String::from("foo"),
19642033
String::from("bar"),
19652034
String::from("baz"),
19662035
];
19672036
for s in strings2 {} // $ type=s:String
19682037

1969-
let strings3 = &[ // $ type=strings3:&T.[T;...].String
2038+
let strings3 = // $ type=strings3:&T.[T;...].String
2039+
&[
19702040
String::from("foo"),
19712041
String::from("bar"),
19722042
String::from("baz"),
19732043
];
19742044
for s in strings3 {} // $ MISSING: type=s:String
19752045

19762046
let callables = [MyCallable::new(), MyCallable::new(), MyCallable::new()]; // $ MISSING: type=callables:[T;...].MyCallable; 3
1977-
for c in callables // $ type=c:MyCallable
2047+
for c // $ type=c:MyCallable
2048+
in callables
19782049
{
19792050
let result = c.call(); // $ type=result:i64 method=call
19802051
}
@@ -1986,7 +2057,8 @@ mod loops {
19862057
let range = 0..10; // $ MISSING: type=range:Range type=range:Idx.i32
19872058
for i in range {} // $ MISSING: type=i:i32
19882059

1989-
let range1 = std::ops::Range { // $ type=range1:Range type=range1:Idx.u16
2060+
let range1 = // $ type=range1:Range type=range1:Idx.u16
2061+
std::ops::Range {
19902062
start: 0u16,
19912063
end: 10u16,
19922064
};
@@ -2031,10 +2103,11 @@ mod loops {
20312103
// while loops
20322104

20332105
let mut a: i64 = 0; // $ type=a:i64
2034-
while a < 10 // $ method=lt type=a:i64
2106+
#[rustfmt::skip]
2107+
let _ = while a < 10 // $ method=lt type=a:i64
20352108
{
20362109
a += 1; // $ type=a:i64 method=add_assign
2037-
}
2110+
};
20382111
}
20392112
}
20402113

0 commit comments

Comments
 (0)