Skip to content

Commit bea855a

Browse files
authored
Merge pull request #19789 from paldepind/rust/operator-borrowing
Rust: Account for borrows in operators in type inference
2 parents 7678679 + f18acdf commit bea855a

File tree

8 files changed

+1105
-1047
lines changed

8 files changed

+1105
-1047
lines changed

rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ final class CallCfgNode extends ExprCfgNode {
182182
}
183183

184184
/** Gets the `i`th argument of this call, if any. */
185-
ExprCfgNode getArgument(int i) {
186-
any(ChildMapping mapping).hasCfgChild(node, node.getArgument(i), this, result)
185+
ExprCfgNode getPositionalArgument(int i) {
186+
any(ChildMapping mapping).hasCfgChild(node, node.getPositionalArgument(i), this, result)
187187
}
188188
}
189189

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ final class ParameterPosition extends TParameterPosition {
132132
final class ArgumentPosition extends ParameterPosition {
133133
/** Gets the argument of `call` at this position, if any. */
134134
Expr getArgument(Call call) {
135-
result = call.getArgument(this.getPosition())
135+
result = call.getPositionalArgument(this.getPosition())
136136
or
137137
result = call.getReceiver() and this.isSelf()
138138
}
@@ -145,7 +145,7 @@ final class ArgumentPosition extends ParameterPosition {
145145
* as the synthetic `ReceiverNode` is the argument for the `self` parameter.
146146
*/
147147
predicate isArgumentForCall(ExprCfgNode arg, CallCfgNode call, ParameterPosition pos) {
148-
call.getArgument(pos.getPosition()) = arg
148+
call.getPositionalArgument(pos.getPosition()) = arg
149149
or
150150
call.getReceiver() = arg and pos.isSelf() and not call.getCall().receiverImplicitlyBorrowed()
151151
}

rust/ql/lib/codeql/rust/elements/Call.qll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
private import internal.CallImpl
66

7+
final class ArgumentPosition = Impl::ArgumentPosition;
8+
79
final class Call = Impl::Call;

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,58 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
55
private import codeql.rust.elements.Operation
66

77
module Impl {
8+
newtype TArgumentPosition =
9+
TPositionalArgumentPosition(int i) {
10+
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()]) - 1]
11+
} or
12+
TSelfArgumentPosition()
13+
14+
/** An argument position in a call. */
15+
class ArgumentPosition extends TArgumentPosition {
16+
/** Gets the index of the argument in the call, if this is a positional argument. */
17+
int asPosition() { this = TPositionalArgumentPosition(result) }
18+
19+
/** Holds if this call position is a self argument. */
20+
predicate isSelf() { this instanceof TSelfArgumentPosition }
21+
22+
/** Gets a string representation of this argument position. */
23+
string toString() {
24+
result = this.asPosition().toString()
25+
or
26+
this.isSelf() and result = "self"
27+
}
28+
}
29+
830
/**
931
* An expression that calls a function.
1032
*
11-
* This class abstracts over the different ways in which a function can be called in Rust.
33+
* This class abstracts over the different ways in which a function can be
34+
* called in Rust.
1235
*/
1336
abstract class Call extends ExprImpl::Expr {
14-
/** Gets the number of arguments _excluding_ any `self` argument. */
15-
abstract int getNumberOfArguments();
16-
17-
/** Gets the receiver of this call if it is a method call. */
18-
abstract Expr getReceiver();
19-
20-
/** Holds if the call has a receiver that might be implicitly borrowed. */
21-
abstract predicate receiverImplicitlyBorrowed();
37+
/** Holds if the receiver of this call is implicitly borrowed. */
38+
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
2239

2340
/** Gets the trait targeted by this call, if any. */
2441
abstract Trait getTrait();
2542

2643
/** Gets the name of the method called if this call is a method call. */
2744
abstract string getMethodName();
2845

46+
/** Gets the argument at the given position, if any. */
47+
abstract Expr getArgument(ArgumentPosition pos);
48+
49+
/** Holds if the argument at `pos` might be implicitly borrowed. */
50+
abstract predicate implicitBorrowAt(ArgumentPosition pos);
51+
52+
/** Gets the number of arguments _excluding_ any `self` argument. */
53+
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
54+
2955
/** Gets the `i`th argument of this call, if any. */
30-
abstract Expr getArgument(int i);
56+
Expr getPositionalArgument(int i) { result = this.getArgument(TPositionalArgumentPosition(i)) }
57+
58+
/** Gets the receiver of this call if it is a method call. */
59+
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }
3160

3261
/** Gets the static target of this call, if any. */
3362
Function getStaticTarget() {
@@ -54,15 +83,13 @@ module Impl {
5483

5584
override string getMethodName() { none() }
5685

57-
override Expr getReceiver() { none() }
58-
5986
override Trait getTrait() { none() }
6087

61-
override predicate receiverImplicitlyBorrowed() { none() }
62-
63-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
88+
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
6489

65-
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
90+
override Expr getArgument(ArgumentPosition pos) {
91+
result = super.getArgList().getArg(pos.asPosition())
92+
}
6693
}
6794

6895
private class CallExprMethodCall extends Call instanceof CallExpr {
@@ -73,8 +100,6 @@ module Impl {
73100

74101
override string getMethodName() { result = methodName }
75102

76-
override Expr getReceiver() { result = super.getArgList().getArg(0) }
77-
78103
override Trait getTrait() {
79104
result = resolvePath(qualifier) and
80105
// When the qualifier is `Self` and resolves to a trait, it's inside a
@@ -84,43 +109,50 @@ module Impl {
84109
qualifier.toString() != "Self"
85110
}
86111

87-
override predicate receiverImplicitlyBorrowed() { none() }
112+
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
88113

89-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() - 1 }
90-
91-
override Expr getArgument(int i) { result = super.getArgList().getArg(i + 1) }
114+
override Expr getArgument(ArgumentPosition pos) {
115+
pos.isSelf() and result = super.getArgList().getArg(0)
116+
or
117+
result = super.getArgList().getArg(pos.asPosition() + 1)
118+
}
92119
}
93120

94121
private class MethodCallExprCall extends Call instanceof MethodCallExpr {
95122
override string getMethodName() { result = super.getIdentifier().getText() }
96123

97-
override Expr getReceiver() { result = this.(MethodCallExpr).getReceiver() }
98-
99124
override Trait getTrait() { none() }
100125

101-
override predicate receiverImplicitlyBorrowed() { any() }
126+
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
102127

103-
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
104-
105-
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
128+
override Expr getArgument(ArgumentPosition pos) {
129+
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
130+
or
131+
result = super.getArgList().getArg(pos.asPosition())
132+
}
106133
}
107134

108135
private class OperatorCall extends Call instanceof Operation {
109136
Trait trait;
110137
string methodName;
138+
int borrows;
111139

112-
OperatorCall() { super.isOverloaded(trait, methodName) }
140+
OperatorCall() { super.isOverloaded(trait, methodName, borrows) }
113141

114142
override string getMethodName() { result = methodName }
115143

116-
override Expr getReceiver() { result = super.getOperand(0) }
117-
118144
override Trait getTrait() { result = trait }
119145

120-
override predicate receiverImplicitlyBorrowed() { none() }
121-
122-
override int getNumberOfArguments() { result = super.getNumberOfOperands() - 1 }
146+
override predicate implicitBorrowAt(ArgumentPosition pos) {
147+
pos.isSelf() and borrows >= 1
148+
or
149+
pos.asPosition() = 0 and borrows = 2
150+
}
123151

124-
override Expr getArgument(int i) { result = super.getOperand(1) and i = 0 }
152+
override Expr getArgument(ArgumentPosition pos) {
153+
pos.isSelf() and result = super.getOperand(0)
154+
or
155+
pos.asPosition() = 0 and result = super.getOperand(1)
156+
}
125157
}
126158
}

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,79 +9,80 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
99

1010
/**
1111
* Holds if the operator `op` with arity `arity` is overloaded to a trait with
12-
* the canonical path `path` and the method name `method`.
12+
* the canonical path `path` and the method name `method`, and if it borrows its
13+
* first `borrows` arguments.
1314
*/
14-
private predicate isOverloaded(string op, int arity, string path, string method) {
15+
private predicate isOverloaded(string op, int arity, string path, string method, int borrows) {
1516
arity = 1 and
1617
(
1718
// Negation
18-
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
19+
op = "-" and path = "core::ops::arith::Neg" and method = "neg" and borrows = 0
1920
or
2021
// Not
21-
op = "!" and path = "core::ops::bit::Not" and method = "not"
22+
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
2223
or
2324
// Dereference
24-
op = "*" and path = "core::ops::deref::Deref" and method = "deref"
25+
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
2526
)
2627
or
2728
arity = 2 and
2829
(
2930
// Comparison operators
30-
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
31+
op = "==" and path = "core::cmp::PartialEq" and method = "eq" and borrows = 2
3132
or
32-
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
33+
op = "!=" and path = "core::cmp::PartialEq" and method = "ne" and borrows = 2
3334
or
34-
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
35+
op = "<" and path = "core::cmp::PartialOrd" and method = "lt" and borrows = 2
3536
or
36-
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
37+
op = "<=" and path = "core::cmp::PartialOrd" and method = "le" and borrows = 2
3738
or
38-
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
39+
op = ">" and path = "core::cmp::PartialOrd" and method = "gt" and borrows = 2
3940
or
40-
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
41+
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge" and borrows = 2
4142
or
4243
// Arithmetic operators
43-
op = "+" and path = "core::ops::arith::Add" and method = "add"
44+
op = "+" and path = "core::ops::arith::Add" and method = "add" and borrows = 0
4445
or
45-
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
46+
op = "-" and path = "core::ops::arith::Sub" and method = "sub" and borrows = 0
4647
or
47-
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
48+
op = "*" and path = "core::ops::arith::Mul" and method = "mul" and borrows = 0
4849
or
49-
op = "/" and path = "core::ops::arith::Div" and method = "div"
50+
op = "/" and path = "core::ops::arith::Div" and method = "div" and borrows = 0
5051
or
51-
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
52+
op = "%" and path = "core::ops::arith::Rem" and method = "rem" and borrows = 0
5253
or
5354
// Arithmetic assignment expressions
54-
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
55+
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign" and borrows = 1
5556
or
56-
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
57+
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign" and borrows = 1
5758
or
58-
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
59+
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign" and borrows = 1
5960
or
60-
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
61+
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign" and borrows = 1
6162
or
62-
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
63+
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign" and borrows = 1
6364
or
6465
// Bitwise operators
65-
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
66+
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand" and borrows = 0
6667
or
67-
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
68+
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor" and borrows = 0
6869
or
69-
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
70+
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor" and borrows = 0
7071
or
71-
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
72+
op = "<<" and path = "core::ops::bit::Shl" and method = "shl" and borrows = 0
7273
or
73-
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
74+
op = ">>" and path = "core::ops::bit::Shr" and method = "shr" and borrows = 0
7475
or
7576
// Bitwise assignment operators
76-
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
77+
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign" and borrows = 1
7778
or
78-
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
79+
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign" and borrows = 1
7980
or
80-
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
81+
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign" and borrows = 1
8182
or
82-
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
83+
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign" and borrows = 1
8384
or
84-
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
85+
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign" and borrows = 1
8586
)
8687
}
8788

@@ -114,9 +115,9 @@ module Impl {
114115
* Holds if this operation is overloaded to the method `methodName` of the
115116
* trait `trait`.
116117
*/
117-
predicate isOverloaded(Trait trait, string methodName) {
118+
predicate isOverloaded(Trait trait, string methodName, int borrows) {
118119
isOverloaded(this.getOperatorName(), this.getNumberOfOperands(), trait.getCanonicalPath(),
119-
methodName)
120+
methodName, borrows)
120121
}
121122
}
122123
}

0 commit comments

Comments
 (0)