Skip to content

Commit ebdffcc

Browse files
committed
Rust: Refactor and generalize Call
1 parent 87b52cc commit ebdffcc

File tree

5 files changed

+102
-72
lines changed

5 files changed

+102
-72
lines changed

rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ final class CallCfgNode extends ExprCfgNode {
182182
}
183183

184184
/** Gets the `i`th argument of this call, if any. */
185-
ExprCfgNode getArgument(int i) {
186-
any(ChildMapping mapping).hasCfgChild(node, node.getArgument(i), this, result)
185+
ExprCfgNode getPositionalArgument(int i) {
186+
any(ChildMapping mapping).hasCfgChild(node, node.getPositionalArgument(i), this, result)
187187
}
188188
}
189189

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ final class ParameterPosition extends TParameterPosition {
133133
final class ArgumentPosition extends ParameterPosition {
134134
/** Gets the argument of `call` at this position, if any. */
135135
Expr getArgument(Call call) {
136-
result = call.getArgument(this.getPosition())
136+
result = call.getPositionalArgument(this.getPosition())
137137
or
138138
result = call.getReceiver() and this.isSelf()
139139
}
@@ -146,7 +146,7 @@ final class ArgumentPosition extends ParameterPosition {
146146
* as the synthetic `ReceiverNode` is the argument for the `self` parameter.
147147
*/
148148
predicate isArgumentForCall(ExprCfgNode arg, CallCfgNode call, ParameterPosition pos) {
149-
call.getArgument(pos.getPosition()) = arg
149+
call.getPositionalArgument(pos.getPosition()) = arg
150150
or
151151
call.getReceiver() = arg and pos.isSelf() and not call.getCall().receiverImplicitlyBorrowed()
152152
}

rust/ql/lib/codeql/rust/elements/Call.qll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
private import internal.CallImpl
66

7+
final class ArgumentPosition = Impl::ArgumentPosition;
8+
79
final class Call = Impl::Call;

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,58 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
55
private import codeql.rust.elements.Operation
66

77
module Impl {
8+
newtype TArgumentPosition =
9+
TPositionalArgumentPosition(int i) {
10+
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()]) - 1]
11+
} or
12+
TSelfArgumentPosition()
13+
14+
/** An argument position in a call. */
15+
class ArgumentPosition extends TArgumentPosition {
16+
/** Gets the index of the argument in the call, if this is a positional argument. */
17+
int asPosition() { this = TPositionalArgumentPosition(result) }
18+
19+
/** Holds if this call position is a self argument. */
20+
predicate isSelf() { this instanceof TSelfArgumentPosition }
21+
22+
/** Gets a string representation of this argument position. */
23+
string toString() {
24+
result = this.asPosition().toString()
25+
or
26+
this.isSelf() and result = "self"
27+
}
28+
}
29+
830
/**
931
* An expression that calls a function.
1032
*
11-
* This class abstracts over the different ways in which a function can be called in Rust.
33+
* This class abstracts over the different ways in which a function can be
34+
* called in Rust.
1235
*/
1336
abstract class Call extends ExprImpl::Expr {
14-
/** Gets the number of arguments _excluding_ any `self` argument. */
15-
abstract int getNumberOfArguments();
16-
17-
/** Gets the receiver of this call if it is a method call. */
18-
abstract Expr getReceiver();
19-
20-
/** Holds if the call has a receiver that might be implicitly borrowed. */
21-
abstract predicate receiverImplicitlyBorrowed();
37+
/** Holds if the receiver of this call is implicitly borrowed. */
38+
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
2239

2340
/** Gets the trait targeted by this call, if any. */
2441
abstract Trait getTrait();
2542

2643
/** Gets the name of the method called if this call is a method call. */
2744
abstract string getMethodName();
2845

46+
/** Gets the argument at the given position, if any. */
47+
abstract Expr getArgument(ArgumentPosition pos);
48+
49+
/** Holds if the argument at `pos` might be implicitly borrowed. */
50+
abstract predicate implicitBorrowAt(ArgumentPosition pos);
51+
52+
/** Gets the number of arguments _excluding_ any `self` argument. */
53+
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
54+
2955
/** Gets the `i`th argument of this call, if any. */
30-
abstract Expr getArgument(int i);
56+
Expr getPositionalArgument(int i) { result = this.getArgument(TPositionalArgumentPosition(i)) }
57+
58+
/** Gets the receiver of this call if it is a method call. */
59+
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }
3160

3261
/** Gets the static target of this call, if any. */
3362
Function getStaticTarget() {
@@ -54,15 +83,13 @@ module Impl {
5483

5584
override string getMethodName() { none() }
5685

57-
override Expr getReceiver() { none() }
58-
5986
override Trait getTrait() { none() }
6087

61-
override predicate receiverImplicitlyBorrowed() { none() }
88+
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
6289

63-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
64-
65-
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
90+
override Expr getArgument(ArgumentPosition pos) {
91+
result = super.getArgList().getArg(pos.asPosition())
92+
}
6693
}
6794

6895
private class CallExprMethodCall extends Call instanceof CallExpr {
@@ -73,8 +100,6 @@ module Impl {
73100

74101
override string getMethodName() { result = methodName }
75102

76-
override Expr getReceiver() { result = super.getArgList().getArg(0) }
77-
78103
override Trait getTrait() {
79104
result = resolvePath(qualifier) and
80105
// When the qualifier is `Self` and resolves to a trait, it's inside a
@@ -84,25 +109,27 @@ module Impl {
84109
qualifier.toString() != "Self"
85110
}
86111

87-
override predicate receiverImplicitlyBorrowed() { none() }
88-
89-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() - 1 }
112+
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
90113

91-
override Expr getArgument(int i) { result = super.getArgList().getArg(i + 1) }
114+
override Expr getArgument(ArgumentPosition pos) {
115+
pos.isSelf() and result = super.getArgList().getArg(0)
116+
or
117+
result = super.getArgList().getArg(pos.asPosition() + 1)
118+
}
92119
}
93120

94121
private class MethodCallExprCall extends Call instanceof MethodCallExpr {
95122
override string getMethodName() { result = super.getIdentifier().getText() }
96123

97-
override Expr getReceiver() { result = this.(MethodCallExpr).getReceiver() }
98-
99124
override Trait getTrait() { none() }
100125

101-
override predicate receiverImplicitlyBorrowed() { any() }
126+
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
102127

103-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
104-
105-
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
128+
override Expr getArgument(ArgumentPosition pos) {
129+
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
130+
or
131+
result = super.getArgList().getArg(pos.asPosition())
132+
}
106133
}
107134

108135
private class OperatorCall extends Call instanceof Operation {
@@ -113,14 +140,14 @@ module Impl {
113140

114141
override string getMethodName() { result = methodName }
115142

116-
override Expr getReceiver() { result = super.getOperand(0) }
117-
118143
override Trait getTrait() { result = trait }
119144

120-
override predicate receiverImplicitlyBorrowed() { none() }
145+
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
121146

122-
override int getNumberOfArguments() { result = super.getNumberOfOperands() - 1 }
123-
124-
override Expr getArgument(int i) { result = super.getOperand(1) and i = 0 }
147+
override Expr getArgument(ArgumentPosition pos) {
148+
pos.isSelf() and result = super.getOperand(0)
149+
or
150+
pos.asPosition() = 0 and result = super.getOperand(1)
151+
}
125152
}
126153
}

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -503,22 +503,20 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
503503
private predicate paramPos(ParamList pl, Param p, int pos) { p = pl.getParam(pos) }
504504

505505
private newtype TDeclarationPosition =
506-
TSelfDeclarationPosition() or
507-
TPositionalDeclarationPosition(int pos) { paramPos(_, _, pos) } or
506+
TArgumentDeclarationPosition(ArgumentPosition pos) or
508507
TReturnDeclarationPosition()
509508

510509
class DeclarationPosition extends TDeclarationPosition {
511-
predicate isSelf() { this = TSelfDeclarationPosition() }
510+
predicate isSelf() { this.asArgumentPosition().isSelf() }
512511

513-
int asPosition() { this = TPositionalDeclarationPosition(result) }
512+
int asPosition() { result = this.asArgumentPosition().asPosition() }
513+
514+
ArgumentPosition asArgumentPosition() { this = TArgumentDeclarationPosition(result) }
514515

515516
predicate isReturn() { this = TReturnDeclarationPosition() }
516517

517518
string toString() {
518-
this.isSelf() and
519-
result = "self"
520-
or
521-
result = this.asPosition().toString()
519+
result = this.asArgumentPosition().toString()
522520
or
523521
this.isReturn() and
524522
result = "(return)"
@@ -551,7 +549,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
551549
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
552550
exists(int pos |
553551
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
554-
dpos = TPositionalDeclarationPosition(pos)
552+
pos = dpos.asPosition()
555553
)
556554
}
557555

@@ -572,9 +570,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
572570
}
573571

574572
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
575-
exists(int p |
576-
result = this.getTupleField(p).getTypeRepr().(TypeMention).resolveTypeAt(path) and
577-
dpos = TPositionalDeclarationPosition(p)
573+
exists(int pos |
574+
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
575+
pos = dpos.asPosition()
578576
)
579577
}
580578

@@ -609,7 +607,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
609607
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
610608
exists(Param p, int i |
611609
paramPos(this.getParamList(), p, i) and
612-
dpos = TPositionalDeclarationPosition(i) and
610+
i = dpos.asPosition() and
613611
result = inferAnnotatedType(p.getPat(), path)
614612
)
615613
or
@@ -642,22 +640,21 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
642640
}
643641

644642
private newtype TAccessPosition =
645-
TSelfAccessPosition(Boolean implicitlyBorrowed) or
646-
TPositionalAccessPosition(int pos) { exists(TPositionalDeclarationPosition(pos)) } or
643+
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed) or
647644
TReturnAccessPosition()
648645

649646
class AccessPosition extends TAccessPosition {
650-
predicate isSelf(boolean implicitlyBorrowed) { this = TSelfAccessPosition(implicitlyBorrowed) }
647+
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _) }
651648

652-
int asPosition() { this = TPositionalAccessPosition(result) }
649+
predicate isBorrowed() { this = TArgumentAccessPosition(_, true) }
653650

654651
predicate isReturn() { this = TReturnAccessPosition() }
655652

656653
string toString() {
657-
this.isSelf(_) and
658-
result = "self"
659-
or
660-
result = this.asPosition().toString()
654+
exists(ArgumentPosition pos, boolean borrowed |
655+
this = TArgumentAccessPosition(pos, borrowed) and
656+
result = pos + ":" + borrowed
657+
)
661658
or
662659
this.isReturn() and
663660
result = "(return)"
@@ -677,10 +674,11 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
677674
}
678675

679676
AstNode getNodeAt(AccessPosition apos) {
680-
result = this.getArgument(apos.asPosition())
681-
or
682-
result = this.getReceiver() and
683-
if this.receiverImplicitlyBorrowed() then apos.isSelf(true) else apos.isSelf(false)
677+
exists(ArgumentPosition pos, boolean borrowed |
678+
apos = TArgumentAccessPosition(pos, borrowed) and
679+
result = this.getArgument(pos) and
680+
if this.implicitBorrowAt(pos) then borrowed = true else borrowed = false
681+
)
684682
or
685683
result = this and apos.isReturn()
686684
}
@@ -697,9 +695,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
697695
}
698696

699697
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
700-
apos.isSelf(_) and dpos.isSelf()
701-
or
702-
apos.asPosition() = dpos.asPosition()
698+
apos.getArgumentPosition() = dpos.asArgumentPosition()
703699
or
704700
apos.isReturn() and dpos.isReturn()
705701
}
@@ -709,10 +705,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
709705
predicate adjustAccessType(
710706
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
711707
) {
712-
if apos.isSelf(true)
708+
if apos.getArgumentPosition().isSelf() and apos.isBorrowed()
713709
then
714710
exists(Type selfParamType |
715-
selfParamType = target.getParameterType(TSelfDeclarationPosition(), TypePath::nil())
711+
selfParamType =
712+
target
713+
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
714+
TypePath::nil())
716715
|
717716
if selfParamType = TRefType()
718717
then
@@ -771,7 +770,7 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
771770
// temporary workaround until implicit borrows are handled correctly
772771
if a instanceof Operation then apos.isReturn() else any()
773772
|
774-
if apos.isSelf(_)
773+
if apos.getArgumentPosition().isSelf()
775774
then
776775
exists(Type receiverType | receiverType = inferType(n) |
777776
if receiverType = TRefType()
@@ -1356,7 +1355,7 @@ private Function getMethodFromImpl(MethodCall mc) {
13561355
or
13571356
exists(int pos, TypePath path, Type type |
13581357
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
1359-
inferType(mc.getArgument(pos), path) = type
1358+
inferType(mc.getPositionalArgument(pos), path) = type
13601359
)
13611360
)
13621361
}
@@ -1391,7 +1390,8 @@ private module Cached {
13911390
cached
13921391
predicate receiverHasImplicitDeref(AstNode receiver) {
13931392
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
1394-
apos.isSelf(true) and
1393+
apos.getArgumentPosition().isSelf() and
1394+
apos.isBorrowed() and
13951395
receiver = a.getNodeAt(apos) and
13961396
inferType(receiver) = TRefType() and
13971397
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
@@ -1402,7 +1402,8 @@ private module Cached {
14021402
cached
14031403
predicate receiverHasImplicitBorrow(AstNode receiver) {
14041404
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
1405-
apos.isSelf(true) and
1405+
apos.getArgumentPosition().isSelf() and
1406+
apos.isBorrowed() and
14061407
receiver = a.getNodeAt(apos) and
14071408
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
14081409
inferType(receiver) != TRefType()

0 commit comments

Comments
 (0)