Skip to content

Commit 10a6897

Browse files
committed
Rust: Apply inherent method prioritization inside type inference loop
1 parent d5b10c5 commit 10a6897

File tree

3 files changed

+55
-245
lines changed

3 files changed

+55
-245
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 33 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
699699
}
700700

701701
Declaration getTarget() {
702-
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
702+
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
703703
or
704704
result = CallExprImpl::getResolvedFunction(this)
705705
}
@@ -1178,14 +1178,14 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
11781178
methodCandidate(type, name, arity, impl)
11791179
}
11801180

1181-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1182-
pragma[nomagic]
1183-
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1184-
rootType = mc.getTypeAt(TypePath::nil()) and
1185-
name = mc.getMethodName() and
1186-
arity = mc.getNumberOfArguments()
1187-
}
1181+
pragma[nomagic]
1182+
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
1183+
rootType = mc.getTypeAt(TypePath::nil()) and
1184+
name = mc.getMethodName() and
1185+
arity = mc.getNumberOfArguments()
1186+
}
11881187

1188+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
11891189
pragma[nomagic]
11901190
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
11911191
exists(Type rootType, string name, int arity |
@@ -1335,16 +1335,29 @@ private predicate methodResolutionDependsOnArgument(
13351335
}
13361336

13371337
/** Gets a method from an `impl` block that matches the method call `mc`. */
1338+
pragma[nomagic]
13381339
private Function getMethodFromImpl(MethodCall mc) {
1339-
exists(Impl impl |
1340+
exists(Type rootType, string name, int arity, Impl impl |
1341+
isMethodCall(mc, rootType, name, arity) and
13401342
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1341-
result = getMethodSuccessor(impl, mc.getMethodName())
1343+
result = getMethodSuccessor(impl, name) and
1344+
if impl.hasTrait() and not exists(mc.getTrait())
1345+
then
1346+
// inherent methods take precedence over trait methods, so only allow
1347+
// trait methods when there are no matching inherent methods
1348+
forall(Impl other |
1349+
not other.hasTrait() and
1350+
methodCandidate(rootType, name, arity, other)
1351+
|
1352+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isNotInstantiationOf(mc, other, _)
1353+
)
1354+
else any()
13421355
|
13431356
not methodResolutionDependsOnArgument(impl, _, _, _, _, _) and
1344-
result = getMethodSuccessor(impl, mc.getMethodName())
1357+
result = getMethodSuccessor(impl, name)
13451358
or
13461359
exists(int pos, TypePath path, Type type |
1347-
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1360+
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
13481361
inferType(mc.getPositionalArgument(pos), path) = type
13491362
)
13501363
)
@@ -1356,22 +1369,6 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
13561369
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
13571370
}
13581371

1359-
/**
1360-
* Gets a method that the method call `mc` resolves to based on type inference,
1361-
* if any.
1362-
*/
1363-
private Function inferMethodCallTarget(MethodCall mc) {
1364-
// The method comes from an `impl` block targeting the type of the receiver.
1365-
result = getMethodFromImpl(mc)
1366-
or
1367-
// The type of the receiver is a type parameter and the method comes from a
1368-
// trait bound on the type parameter.
1369-
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1370-
or
1371-
// The type of the receiver is an `impl Trait` type.
1372-
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1373-
}
1374-
13751372
cached
13761373
private module Cached {
13771374
private import codeql.rust.internal.CachedStages
@@ -1400,47 +1397,18 @@ private module Cached {
14001397
)
14011398
}
14021399

1403-
private predicate isInherentImplFunction(Function f) {
1404-
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1405-
}
1406-
1407-
private predicate isTraitImplFunction(Function f) {
1408-
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
1409-
}
1410-
1411-
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1412-
result = inferMethodCallTarget(mc) and
1413-
(if result.fromSource() then fromSource = true else fromSource = false) and
1414-
(
1415-
// prioritize inherent implementation methods first
1416-
isInherentImplFunction(result)
1417-
or
1418-
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1419-
(
1420-
// then trait implementation methods
1421-
isTraitImplFunction(result)
1422-
or
1423-
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1424-
(
1425-
// then trait methods with default implementations
1426-
result.hasBody()
1427-
or
1428-
// and finally trait methods without default implementations
1429-
not inferMethodCallTarget(mc).hasBody()
1430-
)
1431-
)
1432-
)
1433-
}
1434-
14351400
/** Gets a method that the method call `mc` resolves to, if any. */
14361401
cached
14371402
Function resolveMethodCallTarget(MethodCall mc) {
1438-
// Functions in source code also gets extracted as library code, due to
1439-
// this duplication we prioritize functions from source code.
1440-
result = resolveMethodCallTargetFrom(mc, true)
1403+
// The method comes from an `impl` block targeting the type of the receiver.
1404+
result = getMethodFromImpl(mc)
1405+
or
1406+
// The type of the receiver is a type parameter and the method comes from a
1407+
// trait bound on the type parameter.
1408+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14411409
or
1442-
not exists(resolveMethodCallTargetFrom(mc, true)) and
1443-
result = resolveMethodCallTargetFrom(mc, false)
1410+
// The type of the receiver is an `impl Trait` type.
1411+
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
14441412
}
14451413

14461414
pragma[inline]

0 commit comments

Comments
 (0)