Skip to content

Commit 2d0e953

Browse files
committed
Rust: Implement basic type inference in QL
1 parent 8d01bbc commit 2d0e953

File tree

17 files changed

+667
-66
lines changed

17 files changed

+667
-66
lines changed

rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll

+5-1
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,12 @@ final class RecordPatCfgNode extends Nodes::RecordPatCfgNode {
264264
PatCfgNode getFieldPat(string field) {
265265
exists(RecordPatField rpf |
266266
rpf = node.getRecordPatFieldList().getAField() and
267-
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result) and
267+
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result)
268+
|
268269
field = rpf.getNameRef().getText()
270+
or
271+
not rpf.hasNameRef() and
272+
field = result.(IdentPatCfgNode).getName().getText()
269273
)
270274
}
271275
}

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

+63-44
Original file line numberDiff line numberDiff line change
@@ -746,10 +746,6 @@ abstract class Content extends TContent {
746746
*/
747747
abstract class VariantContent extends Content { }
748748

749-
private TupleField getVariantTupleField(Variant v, int i) {
750-
result = v.getFieldList().(TupleFieldList).getField(i)
751-
}
752-
753749
/** A tuple variant. */
754750
private class VariantTupleFieldContent extends VariantContent, TVariantTupleFieldContent {
755751
private Variant v;
@@ -763,16 +759,11 @@ private class VariantTupleFieldContent extends VariantContent, TVariantTupleFiel
763759
exists(string name |
764760
name = v.getName().getText() and
765761
// only print indices when the arity is > 1
766-
if exists(getVariantTupleField(v, 1)) then result = name + "(" + pos_ + ")" else result = name
762+
if exists(v.getTupleField(1)) then result = name + "(" + pos_ + ")" else result = name
767763
)
768764
}
769765

770-
final override Location getLocation() { result = getVariantTupleField(v, pos_).getLocation() }
771-
}
772-
773-
private RecordField getVariantRecordField(Variant v, string field) {
774-
result = v.getFieldList().(RecordFieldList).getAField() and
775-
field = result.getName().getText()
766+
final override Location getLocation() { result = v.getTupleField(pos_).getLocation() }
776767
}
777768

778769
/** A record variant. */
@@ -788,34 +779,56 @@ private class VariantRecordFieldContent extends VariantContent, TVariantRecordFi
788779
exists(string name |
789780
name = v.getName().getText() and
790781
// only print field when the arity is > 1
791-
if strictcount(string f | exists(getVariantRecordField(v, f))) > 1
782+
if strictcount(v.getRecordField(_)) > 1
792783
then result = name + "{" + field_ + "}"
793784
else result = name
794785
)
795786
}
796787

797788
final override Location getLocation() {
798-
result = getVariantRecordField(v, field_).getName().getLocation()
789+
result = v.getRecordField(field_).getName().getLocation()
799790
}
800791
}
801792

802-
/** Content stored in a field on a struct. */
803-
class StructFieldContent extends Content, TStructFieldContent {
793+
abstract private class FieldContent extends Content {
794+
pragma[nomagic]
795+
abstract FieldExprCfgNode getAnAccess();
796+
}
797+
798+
/** Content stored in a tuple field on a struct. */
799+
private class StructTupleFieldContent extends FieldContent, TStructTupleFieldContent {
800+
private Struct s;
801+
private int pos_;
802+
803+
StructTupleFieldContent() { this = TStructTupleFieldContent(s, pos_) }
804+
805+
Struct getStruct(int pos) { result = s and pos = pos_ }
806+
807+
override FieldExprCfgNode getAnAccess() {
808+
s.getTupleField(pos_) = result.getFieldExpr().getTupleField()
809+
}
810+
811+
override string toString() { result = s.getName().getText() + "." + pos_.toString() }
812+
813+
override Location getLocation() { result = s.getTupleField(pos_).getLocation() }
814+
}
815+
816+
/** Content stored in a record field on a struct. */
817+
class StructRecordFieldContent extends FieldContent, TStructRecordFieldContent {
804818
private Struct s;
805819
private string field_;
806820

807-
StructFieldContent() { this = TStructFieldContent(s, field_) }
821+
StructRecordFieldContent() { this = TStructRecordFieldContent(s, field_) }
808822

809823
Struct getStruct(string field) { result = s and field = field_ }
810824

825+
override FieldExprCfgNode getAnAccess() {
826+
s.getRecordField(field_) = result.getFieldExpr().getRecordField()
827+
}
828+
811829
override string toString() { result = s.getName().getText() + "." + field_.toString() }
812830

813-
override Location getLocation() {
814-
exists(Name f | f = s.getFieldList().(RecordFieldList).getAField().getName() |
815-
f.getText() = field_ and
816-
result = f.getLocation()
817-
)
818-
}
831+
override Location getLocation() { result = s.getRecordField(field_).getName().getLocation() }
819832
}
820833

821834
/** A captured variable. */
@@ -859,23 +872,23 @@ final class ElementContent extends Content, TElementContent {
859872
* NOTE: Unlike `struct`s and `enum`s tuples are structural and not nominal,
860873
* hence we don't store a canonical path for them.
861874
*/
862-
final class TuplePositionContent extends Content, TTuplePositionContent {
875+
final class TuplePositionContent extends FieldContent, TTuplePositionContent {
863876
private int pos;
864877

865878
TuplePositionContent() { this = TTuplePositionContent(pos) }
866879

867880
int getPosition() { result = pos }
868881

882+
override FieldExprCfgNode getAnAccess() {
883+
// todo: limit to tuple types
884+
result.getNameRef().getText().toInt() = pos
885+
}
886+
869887
override string toString() { result = "tuple." + pos.toString() }
870888

871889
override Location getLocation() { result instanceof EmptyLocation }
872890
}
873891

874-
/** Holds if `access` indexes a tuple at an index corresponding to `c`. */
875-
private predicate fieldTuplePositionContent(FieldExprCfgNode access, TuplePositionContent c) {
876-
access.getNameRef().getText().toInt() = c.getPosition()
877-
}
878-
879892
/** A value that represents a set of `Content`s. */
880893
abstract class ContentSet extends TContentSet {
881894
/** Gets a textual representation of this element. */
@@ -1107,6 +1120,10 @@ module RustDataFlow implements InputSig<Location> {
11071120
pragma[nomagic]
11081121
private predicate tupleVariantDestruction(TupleStructPat p, Variant v) { v = resolvePath(p) }
11091122

1123+
/** Holds if `p` destructs an struct `s`. */
1124+
pragma[nomagic]
1125+
private predicate tupleStructDestruction(TupleStructPat p, Struct s) { s = resolvePath(p) }
1126+
11101127
/** Holds if `p` destructs an enum variant `v`. */
11111128
pragma[nomagic]
11121129
private predicate recordVariantDestruction(RecordPat p, Variant v) { v = resolvePath(p) }
@@ -1128,6 +1145,10 @@ module RustDataFlow implements InputSig<Location> {
11281145
|
11291146
tupleVariantDestruction(pat.getPat(), c.(VariantTupleFieldContent).getVariant(pos))
11301147
or
1148+
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
1149+
or
1150+
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
1151+
or
11311152
VariantInLib::tupleVariantCanonicalDestruction(pat.getPat(), c, pos)
11321153
)
11331154
or
@@ -1144,7 +1165,7 @@ module RustDataFlow implements InputSig<Location> {
11441165
recordVariantDestruction(pat.getPat(), c.(VariantRecordFieldContent).getVariant(field))
11451166
or
11461167
// Pattern destructs a struct.
1147-
structDestruction(pat.getPat(), c.(StructFieldContent).getStruct(field))
1168+
structDestruction(pat.getPat(), c.(StructRecordFieldContent).getStruct(field))
11481169
) and
11491170
node2.asPat() = pat.getFieldPat(field)
11501171
)
@@ -1153,11 +1174,9 @@ module RustDataFlow implements InputSig<Location> {
11531174
node1.asPat().(RefPatCfgNode).getPat() = node2.asPat()
11541175
or
11551176
exists(FieldExprCfgNode access |
1156-
// Read of a tuple entry
1157-
fieldTuplePositionContent(access, c) and
1158-
// TODO: Handle read of a struct field.
11591177
node1.asExpr() = access.getExpr() and
1160-
node2.asExpr() = access
1178+
node2.asExpr() = access and
1179+
access = c.(FieldContent).getAnAccess()
11611180
)
11621181
or
11631182
exists(IndexExprCfgNode arr |
@@ -1213,12 +1232,13 @@ module RustDataFlow implements InputSig<Location> {
12131232
pragma[nomagic]
12141233
private predicate structConstruction(RecordExpr re, Struct s) { s = resolvePath(re) }
12151234

1216-
private predicate tupleAssignment(Node node1, Node node2, TuplePositionContent c) {
1235+
pragma[nomagic]
1236+
private predicate fieldAssignment(Node node1, Node node2, FieldContent c) {
12171237
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
12181238
assignment.getLhs() = access and
1219-
fieldTuplePositionContent(access, c) and
12201239
node1.asExpr() = assignment.getRhs() and
1221-
node2.asExpr() = access.getExpr()
1240+
node2.asExpr() = access.getExpr() and
1241+
access = c.getAnAccess()
12221242
)
12231243
}
12241244

@@ -1240,7 +1260,7 @@ module RustDataFlow implements InputSig<Location> {
12401260
c.(VariantRecordFieldContent).getVariant(field))
12411261
or
12421262
// Expression is for a struct.
1243-
structConstruction(re.getRecordExpr(), c.(StructFieldContent).getStruct(field))
1263+
structConstruction(re.getRecordExpr(), c.(StructRecordFieldContent).getStruct(field))
12441264
) and
12451265
node1.asExpr() = re.getFieldExpr(field) and
12461266
node2.asExpr() = re
@@ -1258,7 +1278,7 @@ module RustDataFlow implements InputSig<Location> {
12581278
node2.asExpr().(ArrayListExprCfgNode).getAnExpr()
12591279
]
12601280
or
1261-
tupleAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
1281+
fieldAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
12621282
or
12631283
exists(AssignmentExprCfgNode assignment, IndexExprCfgNode index |
12641284
c instanceof ElementContent and
@@ -1294,7 +1314,7 @@ module RustDataFlow implements InputSig<Location> {
12941314
* in `x.f = newValue`.
12951315
*/
12961316
predicate clearsContent(Node n, ContentSet cs) {
1297-
tupleAssignment(_, n, cs.(SingletonContentSet).getContent())
1317+
fieldAssignment(_, n, cs.(SingletonContentSet).getContent())
12981318
or
12991319
FlowSummaryImpl::Private::Steps::summaryClearsContent(n.(Node::FlowSummaryNode).getSummaryNode(),
13001320
cs)
@@ -1600,10 +1620,10 @@ private module Cached {
16001620

16011621
cached
16021622
newtype TContent =
1603-
TVariantTupleFieldContent(Variant v, int pos) { exists(getVariantTupleField(v, pos)) } or
1623+
TVariantTupleFieldContent(Variant v, int pos) { exists(v.getTupleField(pos)) } or
16041624
// TODO: Remove once library types are extracted
16051625
TVariantInLibTupleFieldContent(VariantInLib::VariantInLib v, int pos) { pos = v.getAPosition() } or
1606-
TVariantRecordFieldContent(Variant v, string field) { exists(getVariantRecordField(v, field)) } or
1626+
TVariantRecordFieldContent(Variant v, string field) { exists(v.getRecordField(field)) } or
16071627
TElementContent() or
16081628
TTuplePositionContent(int pos) {
16091629
pos in [0 .. max([
@@ -1612,9 +1632,8 @@ private module Cached {
16121632
]
16131633
)]
16141634
} or
1615-
TStructFieldContent(Struct s, string field) {
1616-
field = s.getFieldList().(RecordFieldList).getAField().getName().getText()
1617-
} or
1635+
TStructTupleFieldContent(Struct s, int pos) { exists(s.getTupleField(pos)) } or
1636+
TStructRecordFieldContent(Struct s, string field) { exists(s.getRecordField(field)) } or
16181637
TCapturedVariableContent(VariableCapture::CapturedVariable v) or
16191638
TReferenceContent()
16201639

rust/ql/lib/codeql/rust/dataflow/internal/FlowSummaryImpl.qll

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ module Input implements InputSig<Location, RustDataFlow> {
8383
or
8484
exists(Struct s, string field |
8585
result = "Struct" and
86-
c = TStructFieldContent(s, field) and
86+
c = TStructRecordFieldContent(s, field) and
8787
// TODO: calculate in QL
8888
arg = s.getExtendedCanonicalPath() + "::" + field
8989
)

rust/ql/lib/codeql/rust/elements/internal/CallExprBaseImpl.qll

+5-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module Impl {
1717
private import codeql.rust.elements.internal.CallExprImpl::Impl
1818
private import codeql.rust.elements.internal.PathExprImpl::Impl
1919
private import codeql.rust.elements.internal.PathResolution
20+
private import codeql.rust.elements.internal.TypeInference
2021

2122
pragma[nomagic]
2223
Resolvable getCallResolvable(CallExprBase call) {
@@ -35,9 +36,11 @@ module Impl {
3536
* be statically resolved.
3637
*/
3738
Callable getStaticTarget() {
38-
getCallResolvable(this).resolvesAsItem(result)
39-
or
39+
// getCallResolvable(this).resolvesAsItem(result)
40+
// or
4041
result = resolvePath(this.(CallExpr).getFunction().(PathExpr).getPath())
42+
or
43+
result = resolveMethodCallExpr(this)
4144
}
4245
}
4346
}

rust/ql/lib/codeql/rust/elements/internal/FieldExprImpl.qll

+9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ private import codeql.rust.elements.internal.generated.FieldExpr
1111
* be referenced directly.
1212
*/
1313
module Impl {
14+
private import rust
15+
private import TypeInference as TypeInference
16+
1417
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1518
/**
1619
* A field access expression. For example:
@@ -26,5 +29,11 @@ module Impl {
2629
if abbr = "..." then result = "... ." + name else result = abbr + "." + name
2730
)
2831
}
32+
33+
/** Gets the record field that this access references, if any. */
34+
RecordField getRecordField() { result = TypeInference::resolveRecordFieldExpr(this) }
35+
36+
/** Gets the tuple field that this access references, if any. */
37+
TupleField getTupleField() { result = TypeInference::resolveTupleFieldExpr(this) }
2938
}
3039
}

rust/ql/lib/codeql/rust/elements/internal/PathResolution.qll

+8-2
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,10 @@ abstract private class ImplOrTraitItemNode extends ItemNode {
184184
}
185185
}
186186

187-
private class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
187+
class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
188188
override string getName() { result = "(impl)" }
189189

190-
override Visibility getVisibility() { none() }
190+
override Visibility getVisibility() { result = Impl.super.getVisibility() }
191191
}
192192

193193
private class MacroCallItemNode extends ItemNode instanceof MacroCall {
@@ -232,6 +232,12 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
232232
override Visibility getVisibility() { none() }
233233
}
234234

235+
private class TypeParamItemNode extends ItemNode instanceof TypeParam {
236+
override string getName() { result = TypeParam.super.getName().getText() }
237+
238+
override Visibility getVisibility() { none() }
239+
}
240+
235241
/** Holds if `item` has the name `name` and is a top-level item inside `f`. */
236242
private predicate sourceFileEdge(SourceFile f, string name, ItemNode item) {
237243
item = f.getAnItem() and

rust/ql/lib/codeql/rust/elements/internal/StructImpl.qll

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* INTERNAL: Do not use.
55
*/
66

7+
private import rust
78
private import codeql.rust.elements.internal.generated.Struct
89

910
/**
@@ -20,5 +21,16 @@ module Impl {
2021
*/
2122
class Struct extends Generated::Struct {
2223
override string toString() { result = "struct " + this.getName().getText() }
24+
25+
/** Gets the record field named `name`, if any. */
26+
pragma[nomagic]
27+
RecordField getRecordField(string name) {
28+
result = this.getFieldList().(RecordFieldList).getAField() and
29+
result.getName().getText() = name
30+
}
31+
32+
/** Gets the `i`th tuple field, if any. */
33+
pragma[nomagic]
34+
TupleField getTupleField(int i) { result = this.getFieldList().(TupleFieldList).getField(i) }
2335
}
2436
}

0 commit comments

Comments
 (0)