Skip to content

Commit e88d7ba

Browse files
committed
Rust: Apply inherent method prioritization inside type inference loop
1 parent e5f0ef6 commit e88d7ba

File tree

6 files changed

+84
-254
lines changed

6 files changed

+84
-254
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 51 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
699699
}
700700

701701
Declaration getTarget() {
702-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
702+
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
703703
or
704704
result = CallExprImpl::getResolvedFunction(this)
705705
}
@@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
11781178
methodCandidate(type, name, arity, impl)
11791179
}
11801180

1181-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1182-
pragma[nomagic]
1183-
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1184-
rootType = mc.getTypeAt(TypePath::nil()) and
1185-
name = mc.getMethodName() and
1186-
arity = mc.getNumberOfArguments()
1187-
}
1181+
pragma[nomagic]
1182+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1183+
rootType = mc.getTypeAt(TypePath::nil()) and
1184+
name = mc.getMethodName() and
1185+
arity = mc.getNumberOfArguments()
1186+
}
11881187

1188+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
11891189
pragma[nomagic]
11901190
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
11911191
exists(Type rootType, string name, int arity |
@@ -1334,17 +1334,46 @@ private predicate methodResolutionDependsOnArgument(
13341334
)
13351335
}
13361336

1337+
/**
1338+
* Holds if the method call `mc` has no inherent target, i.e., it does not
1339+
* resolve to a method in an `impl` block for the type of the receiver.
1340+
*/
1341+
pragma[nomagic]
1342+
private predicate methodCallHasNoInherentTarget(MethodCall mc) {
1343+
exists(Type rootType, string name, int arity |
1344+
isMethodCall(mc, rootType, name, arity) and
1345+
forall(Impl impl |
1346+
methodCandidate(rootType, name, arity, impl) and
1347+
not impl.hasTrait()
1348+
|
1349+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, impl, _)
1350+
)
1351+
)
1352+
}
1353+
1354+
pragma[nomagic]
1355+
private predicate methodCallHasImplCandidate(MethodCall mc, Impl impl) {
1356+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1357+
if impl.hasTrait() and not exists(mc.getTrait())
1358+
then
1359+
// inherent methods take precedence over trait methods, so only allow
1360+
// trait methods when there are no matching inherent methods
1361+
methodCallHasNoInherentTarget(mc)
1362+
else any()
1363+
}
1364+
13371365
/** Gets a method from an `impl` block that matches the method call `mc`. */
1366+
pragma[nomagic]
13381367
private Function getMethodFromImpl(MethodCall mc) {
1339-
exists(Impl impl |
1340-
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1341-
result = getMethodSuccessor(impl, mc.getMethodName())
1368+
exists(Impl impl, string name |
1369+
methodCallHasImplCandidate(mc, impl) and
1370+
name = mc.getMethodName() and
1371+
result = getMethodSuccessor(impl, name)
13421372
|
1343-
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
1344-
result = getMethodSuccessor(impl, mc.getMethodName())
1373+
not methodResolutionDependsOnArgument(impl, _, _, _, _, _)
13451374
or
13461375
exists(int pos, TypePath path, Type type |
1347-
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1376+
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
13481377
inferType(mc.getPositionalArgument(pos), path) = type
13491378
)
13501379
)
@@ -1356,22 +1385,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
13561385
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
13571386
}
13581387

1359-
/**
1360-
* Gets a method that the method call `mc` resolves to based on type inference,
1361-
* if any.
1362-
*/
1363-
private Function inferMethodCallTarget(MethodCall mc) {
1364-
// The method comes from an `impl` block targeting the type of the receiver.
1365-
result = getMethodFromImpl(mc)
1366-
or
1367-
// The type of the receiver is a type parameter and the method comes from a
1368-
// trait bound on the type parameter.
1369-
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1370-
or
1371-
// The type of the receiver is an `impl Trait` type.
1372-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1373-
}
1374-
13751388
cached
13761389
private module Cached {
13771390
private import codeql.rust.internal.CachedStages
@@ -1400,47 +1413,18 @@ private module Cached {
14001413
)
14011414
}
14021415

1403-
private predicate isInherentImplFunction(Function f) {
1404-
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1405-
}
1406-
1407-
private predicate isTraitImplFunction(Function f) {
1408-
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1409-
}
1410-
1411-
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1412-
result = inferMethodCallTarget(mc) and
1413-
(if result.fromSource() then fromSource = true else fromSource = false) and
1414-
(
1415-
// prioritize inherent implementation methods first
1416-
isInherentImplFunction(result)
1417-
or
1418-
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1419-
(
1420-
// then trait implementation methods
1421-
isTraitImplFunction(result)
1422-
or
1423-
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1424-
(
1425-
// then trait methods with default implementations
1426-
result.hasBody()
1427-
or
1428-
// and finally trait methods without default implementations
1429-
not inferMethodCallTarget(mc).hasBody()
1430-
)
1431-
)
1432-
)
1433-
}
1434-
14351416
/** Gets a method that the method call `mc` resolves to, if any. */
14361417
cached
14371418
Function resolveMethodCallTarget(MethodCall mc) {
1438-
// Functions in source code also gets extracted as library code, due to
1439-
// this duplication we prioritize functions from source code.
1440-
result = resolveMethodCallTargetFrom(mc, true)
1419+
// The method comes from an `impl` block targeting the type of the receiver.
1420+
result = getMethodFromImpl(mc)
1421+
or
1422+
// The type of the receiver is a type parameter and the method comes from a
1423+
// trait bound on the type parameter.
1424+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14411425
or
1442-
not exists(resolveMethodCallTargetFrom(mc, true)) and
1443-
result = resolveMethodCallTargetFrom(mc, false)
1426+
// The type of the receiver is an `impl Trait` type.
1427+
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14441428
}
14451429

14461430
pragma[inline]

rust/ql/test/library-tests/dataflow/sources/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ multipleCallTargets
1010
| test.rs:168:26:168:111 | ...::_print(...) |
1111
| test.rs:178:30:178:68 | ...::_print(...) |
1212
| test.rs:187:26:187:105 | ...::_print(...) |
13+
| test.rs:228:22:228:72 | ... .read_to_string(...) |
14+
| test.rs:482:22:482:50 | file.read_to_end(...) |
15+
| test.rs:488:22:488:53 | file.read_to_string(...) |
1316
| test.rs:609:18:609:38 | ...::_print(...) |
1417
| test.rs:614:18:614:45 | ...::_print(...) |
1518
| test.rs:618:25:618:49 | address.to_socket_addrs() |
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:1963:13:1963:31 | ...::from(...) |
4-
| main.rs:1964:13:1964:31 | ...::from(...) |
5-
| main.rs:1965:13:1965:31 | ...::from(...) |
6-
| main.rs:1970:13:1970:31 | ...::from(...) |
7-
| main.rs:1971:13:1971:31 | ...::from(...) |
8-
| main.rs:1972:13:1972:31 | ...::from(...) |
9-
| main.rs:2006:21:2006:43 | ...::from(...) |
3+
| main.rs:2032:13:2032:31 | ...::from(...) |
4+
| main.rs:2033:13:2033:31 | ...::from(...) |
5+
| main.rs:2034:13:2034:31 | ...::from(...) |
6+
| main.rs:2040:13:2040:31 | ...::from(...) |
7+
| main.rs:2041:13:2041:31 | ...::from(...) |
8+
| main.rs:2042:13:2042:31 | ...::from(...) |
9+
| main.rs:2078:21:2078:43 | ...::from(...) |

0 commit comments

Comments
 (0)