Skip to content

Commit 84f3f8c

Browse files
committed
Rust: Apply inherent method prioritization inside type inference loop
1 parent d5b10c5 commit 84f3f8c

File tree

6 files changed

+66
-252
lines changed

6 files changed

+66
-252
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 33 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
699699
}
700700

701701
Declaration getTarget() {
702-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
702+
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
703703
or
704704
result = CallExprImpl::getResolvedFunction(this)
705705
}
@@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
11781178
methodCandidate(type, name, arity, impl)
11791179
}
11801180

1181-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1182-
pragma[nomagic]
1183-
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1184-
rootType = mc.getTypeAt(TypePath::nil()) and
1185-
name = mc.getMethodName() and
1186-
arity = mc.getNumberOfArguments()
1187-
}
1181+
pragma[nomagic]
1182+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1183+
rootType = mc.getTypeAt(TypePath::nil()) and
1184+
name = mc.getMethodName() and
1185+
arity = mc.getNumberOfArguments()
1186+
}
11881187

1188+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
11891189
pragma[nomagic]
11901190
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
11911191
exists(Type rootType, string name, int arity |
@@ -1335,16 +1335,29 @@ private predicate methodResolutionDependsOnArgument(
13351335
}
13361336

13371337
/** Gets a method from an `impl` block that matches the method call `mc`. */
1338+
pragma[nomagic]
13381339
private Function getMethodFromImpl(MethodCall mc) {
1339-
exists(Impl impl |
1340+
exists(Type rootType, string name, int arity, Impl impl |
1341+
isMethodCall(mc, rootType, name, arity) and
13401342
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1341-
result = getMethodSuccessor(impl, mc.getMethodName())
1343+
result = getMethodSuccessor(impl, name) and
1344+
if impl.hasTrait() and not exists(mc.getTrait())
1345+
then
1346+
// inherent methods take precedence over trait methods, so only allow
1347+
// trait methods when there are no matching inherent methods
1348+
forall(Impl other |
1349+
not other.hasTrait() and
1350+
methodCandidate(rootType, name, arity, other)
1351+
|
1352+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, other, _)
1353+
)
1354+
else any()
13421355
|
13431356
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
1344-
result = getMethodSuccessor(impl, mc.getMethodName())
1357+
result = getMethodSuccessor(impl, name)
13451358
or
13461359
exists(int pos, TypePath path, Type type |
1347-
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1360+
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
13481361
inferType(mc.getPositionalArgument(pos), path) = type
13491362
)
13501363
)
@@ -1356,22 +1369,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
13561369
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
13571370
}
13581371

1359-
/**
1360-
* Gets a method that the method call `mc` resolves to based on type inference,
1361-
* if any.
1362-
*/
1363-
private Function inferMethodCallTarget(MethodCall mc) {
1364-
// The method comes from an `impl` block targeting the type of the receiver.
1365-
result = getMethodFromImpl(mc)
1366-
or
1367-
// The type of the receiver is a type parameter and the method comes from a
1368-
// trait bound on the type parameter.
1369-
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1370-
or
1371-
// The type of the receiver is an `impl Trait` type.
1372-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1373-
}
1374-
13751372
cached
13761373
private module Cached {
13771374
private import codeql.rust.internal.CachedStages
@@ -1400,47 +1397,18 @@ private module Cached {
14001397
)
14011398
}
14021399

1403-
private predicate isInherentImplFunction(Function f) {
1404-
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1405-
}
1406-
1407-
private predicate isTraitImplFunction(Function f) {
1408-
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1409-
}
1410-
1411-
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1412-
result = inferMethodCallTarget(mc) and
1413-
(if result.fromSource() then fromSource = true else fromSource = false) and
1414-
(
1415-
// prioritize inherent implementation methods first
1416-
isInherentImplFunction(result)
1417-
or
1418-
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1419-
(
1420-
// then trait implementation methods
1421-
isTraitImplFunction(result)
1422-
or
1423-
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1424-
(
1425-
// then trait methods with default implementations
1426-
result.hasBody()
1427-
or
1428-
// and finally trait methods without default implementations
1429-
not inferMethodCallTarget(mc).hasBody()
1430-
)
1431-
)
1432-
)
1433-
}
1434-
14351400
/** Gets a method that the method call `mc` resolves to, if any. */
14361401
cached
14371402
Function resolveMethodCallTarget(MethodCall mc) {
1438-
// Functions in source code also gets extracted as library code, due to
1439-
// this duplication we prioritize functions from source code.
1440-
result = resolveMethodCallTargetFrom(mc, true)
1403+
// The method comes from an `impl` block targeting the type of the receiver.
1404+
result = getMethodFromImpl(mc)
1405+
or
1406+
// The type of the receiver is a type parameter and the method comes from a
1407+
// trait bound on the type parameter.
1408+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14411409
or
1442-
not exists(resolveMethodCallTargetFrom(mc, true)) and
1443-
result = resolveMethodCallTargetFrom(mc, false)
1410+
// The type of the receiver is an `impl Trait` type.
1411+
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14441412
}
14451413

14461414
pragma[inline]

rust/ql/test/library-tests/dataflow/sources/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ multipleCallTargets
1010
| test.rs:168:26:168:111 | ...::_print(...) |
1111
| test.rs:178:30:178:68 | ...::_print(...) |
1212
| test.rs:187:26:187:105 | ...::_print(...) |
13+
| test.rs:228:22:228:72 | ... .read_to_string(...) |
14+
| test.rs:482:22:482:50 | file.read_to_end(...) |
15+
| test.rs:488:22:488:53 | file.read_to_string(...) |
1316
| test.rs:609:18:609:38 | ...::_print(...) |
1417
| test.rs:614:18:614:45 | ...::_print(...) |
1518
| test.rs:618:25:618:49 | address.to_socket_addrs() |
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
multipleCallTargets
22
| dereference.rs:61:15:61:24 | e1.deref() |
3-
| main.rs:1963:13:1963:31 | ...::from(...) |
4-
| main.rs:1964:13:1964:31 | ...::from(...) |
5-
| main.rs:1965:13:1965:31 | ...::from(...) |
6-
| main.rs:1970:13:1970:31 | ...::from(...) |
7-
| main.rs:1971:13:1971:31 | ...::from(...) |
8-
| main.rs:1972:13:1972:31 | ...::from(...) |
9-
| main.rs:2006:21:2006:43 | ...::from(...) |
3+
| main.rs:2032:13:2032:31 | ...::from(...) |
4+
| main.rs:2033:13:2033:31 | ...::from(...) |
5+
| main.rs:2034:13:2034:31 | ...::from(...) |
6+
| main.rs:2040:13:2040:31 | ...::from(...) |
7+
| main.rs:2041:13:2041:31 | ...::from(...) |
8+
| main.rs:2042:13:2042:31 | ...::from(...) |
9+
| main.rs:2078:21:2078:43 | ...::from(...) |

0 commit comments

Comments
 (0)