Skip to content

Commit 66bea42

Browse files
committed
Rust: Implement basic type inference in QL
1 parent 90944d5 commit 66bea42

File tree

16 files changed

+666
-65
lines changed

16 files changed

+666
-65
lines changed

rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,12 @@ final class RecordPatCfgNode extends Nodes::RecordPatCfgNode {
264264
PatCfgNode getFieldPat(string field) {
265265
exists(RecordPatField rpf |
266266
rpf = node.getRecordPatFieldList().getAField() and
267-
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result) and
267+
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result)
268+
|
268269
field = rpf.getNameRef().getText()
270+
or
271+
not rpf.hasNameRef() and
272+
field = result.(IdentPatCfgNode).getName().getText()
269273
)
270274
}
271275
}

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -740,10 +740,6 @@ abstract class Content extends TContent {
740740
*/
741741
abstract class VariantContent extends Content { }
742742

743-
private TupleField getVariantTupleField(Variant v, int i) {
744-
result = v.getFieldList().(TupleFieldList).getField(i)
745-
}
746-
747743
/** A tuple variant. */
748744
private class VariantTupleFieldContent extends VariantContent, TVariantTupleFieldContent {
749745
private Variant v;
@@ -757,16 +753,11 @@ private class VariantTupleFieldContent extends VariantContent, TVariantTupleFiel
757753
exists(string name |
758754
name = v.getName().getText() and
759755
// only print indices when the arity is > 1
760-
if exists(getVariantTupleField(v, 1)) then result = name + "(" + pos_ + ")" else result = name
756+
if exists(v.getTupleField(1)) then result = name + "(" + pos_ + ")" else result = name
761757
)
762758
}
763759

764-
final override Location getLocation() { result = getVariantTupleField(v, pos_).getLocation() }
765-
}
766-
767-
private RecordField getVariantRecordField(Variant v, string field) {
768-
result = v.getFieldList().(RecordFieldList).getAField() and
769-
field = result.getName().getText()
760+
final override Location getLocation() { result = v.getTupleField(pos_).getLocation() }
770761
}
771762

772763
/** A record variant. */
@@ -782,34 +773,56 @@ private class VariantRecordFieldContent extends VariantContent, TVariantRecordFi
782773
exists(string name |
783774
name = v.getName().getText() and
784775
// only print field when the arity is > 1
785-
if strictcount(string f | exists(getVariantRecordField(v, f))) > 1
776+
if strictcount(v.getRecordField(_)) > 1
786777
then result = name + "{" + field_ + "}"
787778
else result = name
788779
)
789780
}
790781

791782
final override Location getLocation() {
792-
result = getVariantRecordField(v, field_).getName().getLocation()
783+
result = v.getRecordField(field_).getName().getLocation()
793784
}
794785
}
795786

796-
/** Content stored in a field on a struct. */
797-
private class StructFieldContent extends Content, TStructFieldContent {
787+
abstract private class FieldContent extends Content {
788+
pragma[nomagic]
789+
abstract FieldExprCfgNode getAnAccess();
790+
}
791+
792+
/** Content stored in a tuple field on a struct. */
793+
private class StructTupleFieldContent extends FieldContent, TStructTupleFieldContent {
794+
private Struct s;
795+
private int pos_;
796+
797+
StructTupleFieldContent() { this = TStructTupleFieldContent(s, pos_) }
798+
799+
Struct getStruct(int pos) { result = s and pos = pos_ }
800+
801+
override FieldExprCfgNode getAnAccess() {
802+
s.getTupleField(pos_) = result.getFieldExpr().getTupleField()
803+
}
804+
805+
override string toString() { result = s.getName().getText() + "." + pos_.toString() }
806+
807+
override Location getLocation() { result = s.getTupleField(pos_).getLocation() }
808+
}
809+
810+
/** Content stored in a record field on a struct. */
811+
private class StructRecordFieldContent extends FieldContent, TStructRecordFieldContent {
798812
private Struct s;
799813
private string field_;
800814

801-
StructFieldContent() { this = TStructFieldContent(s, field_) }
815+
StructRecordFieldContent() { this = TStructRecordFieldContent(s, field_) }
802816

803817
Struct getStruct(string field) { result = s and field = field_ }
804818

819+
override FieldExprCfgNode getAnAccess() {
820+
s.getRecordField(field_) = result.getFieldExpr().getRecordField()
821+
}
822+
805823
override string toString() { result = s.getName().getText() + "." + field_.toString() }
806824

807-
override Location getLocation() {
808-
exists(Name f | f = s.getFieldList().(RecordFieldList).getAField().getName() |
809-
f.getText() = field_ and
810-
result = f.getLocation()
811-
)
812-
}
825+
override Location getLocation() { result = s.getRecordField(field_).getName().getLocation() }
813826
}
814827

815828
/** A captured variable. */
@@ -853,23 +866,23 @@ final class ElementContent extends Content, TElementContent {
853866
* NOTE: Unlike `struct`s and `enum`s tuples are structural and not nominal,
854867
* hence we don't store a canonical path for them.
855868
*/
856-
final class TuplePositionContent extends Content, TTuplePositionContent {
869+
final class TuplePositionContent extends FieldContent, TTuplePositionContent {
857870
private int pos;
858871

859872
TuplePositionContent() { this = TTuplePositionContent(pos) }
860873

861874
int getPosition() { result = pos }
862875

876+
override FieldExprCfgNode getAnAccess() {
877+
// todo: limit to tuple types
878+
result.getNameRef().getText().toInt() = pos
879+
}
880+
863881
override string toString() { result = "tuple." + pos.toString() }
864882

865883
override Location getLocation() { result instanceof EmptyLocation }
866884
}
867885

868-
/** Holds if `access` indexes a tuple at an index corresponding to `c`. */
869-
private predicate fieldTuplePositionContent(FieldExprCfgNode access, TuplePositionContent c) {
870-
access.getNameRef().getText().toInt() = c.getPosition()
871-
}
872-
873886
/** A value that represents a set of `Content`s. */
874887
abstract class ContentSet extends TContentSet {
875888
/** Gets a textual representation of this element. */
@@ -1101,6 +1114,10 @@ module RustDataFlow implements InputSig<Location> {
11011114
pragma[nomagic]
11021115
private predicate tupleVariantDestruction(TupleStructPat p, Variant v) { v = resolvePath(p) }
11031116

1117+
/** Holds if `p` destructs an struct `s`. */
1118+
pragma[nomagic]
1119+
private predicate tupleStructDestruction(TupleStructPat p, Struct s) { s = resolvePath(p) }
1120+
11041121
/** Holds if `p` destructs an enum variant `v`. */
11051122
pragma[nomagic]
11061123
private predicate recordVariantDestruction(RecordPat p, Variant v) { v = resolvePath(p) }
@@ -1122,6 +1139,10 @@ module RustDataFlow implements InputSig<Location> {
11221139
|
11231140
tupleVariantDestruction(pat.getPat(), c.(VariantTupleFieldContent).getVariant(pos))
11241141
or
1142+
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
1143+
or
1144+
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
1145+
or
11251146
VariantInLib::tupleVariantCanonicalDestruction(pat.getPat(), c, pos)
11261147
)
11271148
or
@@ -1138,7 +1159,7 @@ module RustDataFlow implements InputSig<Location> {
11381159
recordVariantDestruction(pat.getPat(), c.(VariantRecordFieldContent).getVariant(field))
11391160
or
11401161
// Pattern destructs a struct.
1141-
structDestruction(pat.getPat(), c.(StructFieldContent).getStruct(field))
1162+
structDestruction(pat.getPat(), c.(StructRecordFieldContent).getStruct(field))
11421163
) and
11431164
node2.asPat() = pat.getFieldPat(field)
11441165
)
@@ -1147,11 +1168,9 @@ module RustDataFlow implements InputSig<Location> {
11471168
node1.asPat().(RefPatCfgNode).getPat() = node2.asPat()
11481169
or
11491170
exists(FieldExprCfgNode access |
1150-
// Read of a tuple entry
1151-
fieldTuplePositionContent(access, c) and
1152-
// TODO: Handle read of a struct field.
11531171
node1.asExpr() = access.getExpr() and
1154-
node2.asExpr() = access
1172+
node2.asExpr() = access and
1173+
access = c.(FieldContent).getAnAccess()
11551174
)
11561175
or
11571176
exists(IndexExprCfgNode arr |
@@ -1207,12 +1226,13 @@ module RustDataFlow implements InputSig<Location> {
12071226
pragma[nomagic]
12081227
private predicate structConstruction(RecordExpr re, Struct s) { s = resolvePath(re) }
12091228

1210-
private predicate tupleAssignment(Node node1, Node node2, TuplePositionContent c) {
1229+
pragma[nomagic]
1230+
private predicate fieldAssignment(Node node1, Node node2, FieldContent c) {
12111231
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
12121232
assignment.getLhs() = access and
1213-
fieldTuplePositionContent(access, c) and
12141233
node1.asExpr() = assignment.getRhs() and
1215-
node2.asExpr() = access.getExpr()
1234+
node2.asExpr() = access.getExpr() and
1235+
access = c.getAnAccess()
12161236
)
12171237
}
12181238

@@ -1234,7 +1254,7 @@ module RustDataFlow implements InputSig<Location> {
12341254
c.(VariantRecordFieldContent).getVariant(field))
12351255
or
12361256
// Expression is for a struct.
1237-
structConstruction(re.getRecordExpr(), c.(StructFieldContent).getStruct(field))
1257+
structConstruction(re.getRecordExpr(), c.(StructRecordFieldContent).getStruct(field))
12381258
) and
12391259
node1.asExpr() = re.getFieldExpr(field) and
12401260
node2.asExpr() = re
@@ -1252,7 +1272,7 @@ module RustDataFlow implements InputSig<Location> {
12521272
node2.asExpr().(ArrayListExprCfgNode).getAnExpr()
12531273
]
12541274
or
1255-
tupleAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
1275+
fieldAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
12561276
or
12571277
exists(AssignmentExprCfgNode assignment, IndexExprCfgNode index |
12581278
c instanceof ElementContent and
@@ -1288,7 +1308,7 @@ module RustDataFlow implements InputSig<Location> {
12881308
* in `x.f = newValue`.
12891309
*/
12901310
predicate clearsContent(Node n, ContentSet cs) {
1291-
tupleAssignment(_, n, cs.(SingletonContentSet).getContent())
1311+
fieldAssignment(_, n, cs.(SingletonContentSet).getContent())
12921312
or
12931313
FlowSummaryImpl::Private::Steps::summaryClearsContent(n.(Node::FlowSummaryNode).getSummaryNode(),
12941314
cs)
@@ -1594,10 +1614,10 @@ private module Cached {
15941614

15951615
cached
15961616
newtype TContent =
1597-
TVariantTupleFieldContent(Variant v, int pos) { exists(getVariantTupleField(v, pos)) } or
1617+
TVariantTupleFieldContent(Variant v, int pos) { exists(v.getTupleField(pos)) } or
15981618
// TODO: Remove once library types are extracted
15991619
TVariantInLibTupleFieldContent(VariantInLib::VariantInLib v, int pos) { pos = v.getAPosition() } or
1600-
TVariantRecordFieldContent(Variant v, string field) { exists(getVariantRecordField(v, field)) } or
1620+
TVariantRecordFieldContent(Variant v, string field) { exists(v.getRecordField(field)) } or
16011621
TElementContent() or
16021622
TTuplePositionContent(int pos) {
16031623
pos in [0 .. max([
@@ -1606,9 +1626,8 @@ private module Cached {
16061626
]
16071627
)]
16081628
} or
1609-
TStructFieldContent(Struct s, string field) {
1610-
field = s.getFieldList().(RecordFieldList).getAField().getName().getText()
1611-
} or
1629+
TStructTupleFieldContent(Struct s, int pos) { exists(s.getTupleField(pos)) } or
1630+
TStructRecordFieldContent(Struct s, string field) { exists(s.getRecordField(field)) } or
16121631
TCapturedVariableContent(VariableCapture::CapturedVariable v) or
16131632
TReferenceContent()
16141633

rust/ql/lib/codeql/rust/dataflow/internal/FlowSummaryImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ module Input implements InputSig<Location, RustDataFlow> {
8383
or
8484
exists(Struct s, string field |
8585
result = "Struct" and
86-
c = TStructFieldContent(s, field) and
86+
c = TStructRecordFieldContent(s, field) and
8787
// TODO: calculate in QL
8888
arg = s.getExtendedCanonicalPath() + "::" + field
8989
)

rust/ql/lib/codeql/rust/elements/internal/CallExprBaseImpl.qll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ module Impl {
1717
private import codeql.rust.elements.internal.CallExprImpl::Impl
1818
private import codeql.rust.elements.internal.PathExprImpl::Impl
1919
private import codeql.rust.elements.internal.PathResolution
20+
private import codeql.rust.elements.internal.TypeInference
2021

2122
pragma[nomagic]
2223
Resolvable getCallResolvable(CallExprBase call) {
@@ -35,9 +36,11 @@ module Impl {
3536
* be statically resolved.
3637
*/
3738
Callable getStaticTarget() {
38-
getCallResolvable(this).resolvesAsItem(result)
39-
or
39+
// getCallResolvable(this).resolvesAsItem(result)
40+
// or
4041
result = resolvePath(this.(CallExpr).getFunction().(PathExpr).getPath())
42+
or
43+
result = resolveMethodCallExpr(this)
4144
}
4245
}
4346
}

rust/ql/lib/codeql/rust/elements/internal/FieldExprImpl.qll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ private import codeql.rust.elements.internal.generated.FieldExpr
1111
* be referenced directly.
1212
*/
1313
module Impl {
14+
private import rust
15+
private import TypeInference as TypeInference
16+
1417
// the following QLdoc is generated: if you need to edit it, do it in the schema file
1518
/**
1619
* A field access expression. For example:
@@ -26,5 +29,11 @@ module Impl {
2629
if abbr = "..." then result = "... ." + name else result = abbr + "." + name
2730
)
2831
}
32+
33+
/** Gets the record field that this access references, if any. */
34+
RecordField getRecordField() { result = TypeInference::resolveRecordFieldExpr(this) }
35+
36+
/** Gets the tuple field that this access references, if any. */
37+
TupleField getTupleField() { result = TypeInference::resolveTupleFieldExpr(this) }
2938
}
3039
}

rust/ql/lib/codeql/rust/elements/internal/PathResolution.qll

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,10 @@ abstract private class ImplOrTraitItemNode extends ItemNode {
184184
}
185185
}
186186

187-
private class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
187+
class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
188188
override string getName() { result = "(impl)" }
189189

190-
override Visibility getVisibility() { none() }
190+
override Visibility getVisibility() { result = Impl.super.getVisibility() }
191191
}
192192

193193
private class MacroCallItemNode extends ItemNode instanceof MacroCall {
@@ -232,6 +232,12 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
232232
override Visibility getVisibility() { none() }
233233
}
234234

235+
private class TypeParamItemNode extends ItemNode instanceof TypeParam {
236+
override string getName() { result = TypeParam.super.getName().getText() }
237+
238+
override Visibility getVisibility() { none() }
239+
}
240+
235241
/** Holds if `item` has the name `name` and is a top-level item inside `f`. */
236242
private predicate sourceFileEdge(SourceFile f, string name, ItemNode item) {
237243
item = f.getAnItem() and

rust/ql/lib/codeql/rust/elements/internal/StructImpl.qll

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* INTERNAL: Do not use.
55
*/
66

7+
private import rust
78
private import codeql.rust.elements.internal.generated.Struct
89

910
/**
@@ -20,5 +21,16 @@ module Impl {
2021
*/
2122
class Struct extends Generated::Struct {
2223
override string toString() { result = "struct " + this.getName().getText() }
24+
25+
/** Gets the record field named `name`, if any. */
26+
pragma[nomagic]
27+
RecordField getRecordField(string name) {
28+
result = this.getFieldList().(RecordFieldList).getAField() and
29+
result.getName().getText() = name
30+
}
31+
32+
/** Gets the `i`th tuple field, if any. */
33+
pragma[nomagic]
34+
TupleField getTupleField(int i) { result = this.getFieldList().(TupleFieldList).getField(i) }
2335
}
2436
}

0 commit comments

Comments
 (0)