Skip to content

Commit

Permalink
Rust: Implement basic type inference in QL
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Feb 4, 2025
1 parent 90944d5 commit 66bea42
Show file tree
Hide file tree
Showing 16 changed files with 666 additions and 65 deletions.
6 changes: 5 additions & 1 deletion rust/ql/lib/codeql/rust/controlflow/CfgNodes.qll
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,12 @@ final class RecordPatCfgNode extends Nodes::RecordPatCfgNode {
PatCfgNode getFieldPat(string field) {
exists(RecordPatField rpf |
rpf = node.getRecordPatFieldList().getAField() and
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result) and
any(ChildMapping mapping).hasCfgChild(node, rpf.getPat(), this, result)
|
field = rpf.getNameRef().getText()
or
not rpf.hasNameRef() and
field = result.(IdentPatCfgNode).getName().getText()
)
}
}
107 changes: 63 additions & 44 deletions rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -740,10 +740,6 @@ abstract class Content extends TContent {
*/
abstract class VariantContent extends Content { }

private TupleField getVariantTupleField(Variant v, int i) {
result = v.getFieldList().(TupleFieldList).getField(i)
}

/** A tuple variant. */
private class VariantTupleFieldContent extends VariantContent, TVariantTupleFieldContent {
private Variant v;
Expand All @@ -757,16 +753,11 @@ private class VariantTupleFieldContent extends VariantContent, TVariantTupleFiel
exists(string name |
name = v.getName().getText() and
// only print indices when the arity is > 1
if exists(getVariantTupleField(v, 1)) then result = name + "(" + pos_ + ")" else result = name
if exists(v.getTupleField(1)) then result = name + "(" + pos_ + ")" else result = name
)
}

final override Location getLocation() { result = getVariantTupleField(v, pos_).getLocation() }
}

private RecordField getVariantRecordField(Variant v, string field) {
result = v.getFieldList().(RecordFieldList).getAField() and
field = result.getName().getText()
final override Location getLocation() { result = v.getTupleField(pos_).getLocation() }
}

/** A record variant. */
Expand All @@ -782,34 +773,56 @@ private class VariantRecordFieldContent extends VariantContent, TVariantRecordFi
exists(string name |
name = v.getName().getText() and
// only print field when the arity is > 1
if strictcount(string f | exists(getVariantRecordField(v, f))) > 1
if strictcount(v.getRecordField(_)) > 1
then result = name + "{" + field_ + "}"
else result = name
)
}

final override Location getLocation() {
result = getVariantRecordField(v, field_).getName().getLocation()
result = v.getRecordField(field_).getName().getLocation()
}
}

/** Content stored in a field on a struct. */
private class StructFieldContent extends Content, TStructFieldContent {
abstract private class FieldContent extends Content {
pragma[nomagic]
abstract FieldExprCfgNode getAnAccess();
}

/** Content stored in a tuple field on a struct. */
private class StructTupleFieldContent extends FieldContent, TStructTupleFieldContent {
private Struct s;
private int pos_;

StructTupleFieldContent() { this = TStructTupleFieldContent(s, pos_) }

Struct getStruct(int pos) { result = s and pos = pos_ }

override FieldExprCfgNode getAnAccess() {
s.getTupleField(pos_) = result.getFieldExpr().getTupleField()
}

override string toString() { result = s.getName().getText() + "." + pos_.toString() }

override Location getLocation() { result = s.getTupleField(pos_).getLocation() }
}

/** Content stored in a record field on a struct. */
private class StructRecordFieldContent extends FieldContent, TStructRecordFieldContent {
private Struct s;
private string field_;

StructFieldContent() { this = TStructFieldContent(s, field_) }
StructRecordFieldContent() { this = TStructRecordFieldContent(s, field_) }

Struct getStruct(string field) { result = s and field = field_ }

override FieldExprCfgNode getAnAccess() {
s.getRecordField(field_) = result.getFieldExpr().getRecordField()
}

override string toString() { result = s.getName().getText() + "." + field_.toString() }

override Location getLocation() {
exists(Name f | f = s.getFieldList().(RecordFieldList).getAField().getName() |
f.getText() = field_ and
result = f.getLocation()
)
}
override Location getLocation() { result = s.getRecordField(field_).getName().getLocation() }
}

/** A captured variable. */
Expand Down Expand Up @@ -853,23 +866,23 @@ final class ElementContent extends Content, TElementContent {
* NOTE: Unlike `struct`s and `enum`s tuples are structural and not nominal,
* hence we don't store a canonical path for them.
*/
final class TuplePositionContent extends Content, TTuplePositionContent {
final class TuplePositionContent extends FieldContent, TTuplePositionContent {
private int pos;

TuplePositionContent() { this = TTuplePositionContent(pos) }

int getPosition() { result = pos }

override FieldExprCfgNode getAnAccess() {
// todo: limit to tuple types
result.getNameRef().getText().toInt() = pos
}

override string toString() { result = "tuple." + pos.toString() }

override Location getLocation() { result instanceof EmptyLocation }
}

/** Holds if `access` indexes a tuple at an index corresponding to `c`. */
private predicate fieldTuplePositionContent(FieldExprCfgNode access, TuplePositionContent c) {
access.getNameRef().getText().toInt() = c.getPosition()
}

/** A value that represents a set of `Content`s. */
abstract class ContentSet extends TContentSet {
/** Gets a textual representation of this element. */
Expand Down Expand Up @@ -1101,6 +1114,10 @@ module RustDataFlow implements InputSig<Location> {
pragma[nomagic]
private predicate tupleVariantDestruction(TupleStructPat p, Variant v) { v = resolvePath(p) }

/** Holds if `p` destructs an struct `s`. */
pragma[nomagic]
private predicate tupleStructDestruction(TupleStructPat p, Struct s) { s = resolvePath(p) }

/** Holds if `p` destructs an enum variant `v`. */
pragma[nomagic]
private predicate recordVariantDestruction(RecordPat p, Variant v) { v = resolvePath(p) }
Expand All @@ -1122,6 +1139,10 @@ module RustDataFlow implements InputSig<Location> {
|
tupleVariantDestruction(pat.getPat(), c.(VariantTupleFieldContent).getVariant(pos))
or
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
or
tupleStructDestruction(pat.getPat(), c.(StructTupleFieldContent).getStruct(pos))
or
VariantInLib::tupleVariantCanonicalDestruction(pat.getPat(), c, pos)
)
or
Expand All @@ -1138,7 +1159,7 @@ module RustDataFlow implements InputSig<Location> {
recordVariantDestruction(pat.getPat(), c.(VariantRecordFieldContent).getVariant(field))
or
// Pattern destructs a struct.
structDestruction(pat.getPat(), c.(StructFieldContent).getStruct(field))
structDestruction(pat.getPat(), c.(StructRecordFieldContent).getStruct(field))
) and
node2.asPat() = pat.getFieldPat(field)
)
Expand All @@ -1147,11 +1168,9 @@ module RustDataFlow implements InputSig<Location> {
node1.asPat().(RefPatCfgNode).getPat() = node2.asPat()
or
exists(FieldExprCfgNode access |
// Read of a tuple entry
fieldTuplePositionContent(access, c) and
// TODO: Handle read of a struct field.
node1.asExpr() = access.getExpr() and
node2.asExpr() = access
node2.asExpr() = access and
access = c.(FieldContent).getAnAccess()
)
or
exists(IndexExprCfgNode arr |
Expand Down Expand Up @@ -1207,12 +1226,13 @@ module RustDataFlow implements InputSig<Location> {
pragma[nomagic]
private predicate structConstruction(RecordExpr re, Struct s) { s = resolvePath(re) }

private predicate tupleAssignment(Node node1, Node node2, TuplePositionContent c) {
pragma[nomagic]
private predicate fieldAssignment(Node node1, Node node2, FieldContent c) {
exists(AssignmentExprCfgNode assignment, FieldExprCfgNode access |
assignment.getLhs() = access and
fieldTuplePositionContent(access, c) and
node1.asExpr() = assignment.getRhs() and
node2.asExpr() = access.getExpr()
node2.asExpr() = access.getExpr() and
access = c.getAnAccess()
)
}

Expand All @@ -1234,7 +1254,7 @@ module RustDataFlow implements InputSig<Location> {
c.(VariantRecordFieldContent).getVariant(field))
or
// Expression is for a struct.
structConstruction(re.getRecordExpr(), c.(StructFieldContent).getStruct(field))
structConstruction(re.getRecordExpr(), c.(StructRecordFieldContent).getStruct(field))
) and
node1.asExpr() = re.getFieldExpr(field) and
node2.asExpr() = re
Expand All @@ -1252,7 +1272,7 @@ module RustDataFlow implements InputSig<Location> {
node2.asExpr().(ArrayListExprCfgNode).getAnExpr()
]
or
tupleAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
fieldAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
or
exists(AssignmentExprCfgNode assignment, IndexExprCfgNode index |
c instanceof ElementContent and
Expand Down Expand Up @@ -1288,7 +1308,7 @@ module RustDataFlow implements InputSig<Location> {
* in `x.f = newValue`.
*/
predicate clearsContent(Node n, ContentSet cs) {
tupleAssignment(_, n, cs.(SingletonContentSet).getContent())
fieldAssignment(_, n, cs.(SingletonContentSet).getContent())
or
FlowSummaryImpl::Private::Steps::summaryClearsContent(n.(Node::FlowSummaryNode).getSummaryNode(),
cs)
Expand Down Expand Up @@ -1594,10 +1614,10 @@ private module Cached {

cached
newtype TContent =
TVariantTupleFieldContent(Variant v, int pos) { exists(getVariantTupleField(v, pos)) } or
TVariantTupleFieldContent(Variant v, int pos) { exists(v.getTupleField(pos)) } or
// TODO: Remove once library types are extracted
TVariantInLibTupleFieldContent(VariantInLib::VariantInLib v, int pos) { pos = v.getAPosition() } or
TVariantRecordFieldContent(Variant v, string field) { exists(getVariantRecordField(v, field)) } or
TVariantRecordFieldContent(Variant v, string field) { exists(v.getRecordField(field)) } or
TElementContent() or
TTuplePositionContent(int pos) {
pos in [0 .. max([
Expand All @@ -1606,9 +1626,8 @@ private module Cached {
]
)]
} or
TStructFieldContent(Struct s, string field) {
field = s.getFieldList().(RecordFieldList).getAField().getName().getText()
} or
TStructTupleFieldContent(Struct s, int pos) { exists(s.getTupleField(pos)) } or
TStructRecordFieldContent(Struct s, string field) { exists(s.getRecordField(field)) } or
TCapturedVariableContent(VariableCapture::CapturedVariable v) or
TReferenceContent()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ module Input implements InputSig<Location, RustDataFlow> {
or
exists(Struct s, string field |
result = "Struct" and
c = TStructFieldContent(s, field) and
c = TStructRecordFieldContent(s, field) and
// TODO: calculate in QL
arg = s.getExtendedCanonicalPath() + "::" + field
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module Impl {
private import codeql.rust.elements.internal.CallExprImpl::Impl
private import codeql.rust.elements.internal.PathExprImpl::Impl
private import codeql.rust.elements.internal.PathResolution
private import codeql.rust.elements.internal.TypeInference

pragma[nomagic]
Resolvable getCallResolvable(CallExprBase call) {
Expand All @@ -35,9 +36,11 @@ module Impl {
* be statically resolved.
*/
Callable getStaticTarget() {
getCallResolvable(this).resolvesAsItem(result)
or
// getCallResolvable(this).resolvesAsItem(result)
// or
result = resolvePath(this.(CallExpr).getFunction().(PathExpr).getPath())
or
result = resolveMethodCallExpr(this)
}
}
}
9 changes: 9 additions & 0 deletions rust/ql/lib/codeql/rust/elements/internal/FieldExprImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ private import codeql.rust.elements.internal.generated.FieldExpr
* be referenced directly.
*/
module Impl {
private import rust
private import TypeInference as TypeInference

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A field access expression. For example:
Expand All @@ -26,5 +29,11 @@ module Impl {
if abbr = "..." then result = "... ." + name else result = abbr + "." + name
)
}

/** Gets the record field that this access references, if any. */
RecordField getRecordField() { result = TypeInference::resolveRecordFieldExpr(this) }

/** Gets the tuple field that this access references, if any. */
TupleField getTupleField() { result = TypeInference::resolveTupleFieldExpr(this) }
}
}
10 changes: 8 additions & 2 deletions rust/ql/lib/codeql/rust/elements/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ abstract private class ImplOrTraitItemNode extends ItemNode {
}
}

private class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
class ImplItemNode extends ImplOrTraitItemNode instanceof Impl {
override string getName() { result = "(impl)" }

override Visibility getVisibility() { none() }
override Visibility getVisibility() { result = Impl.super.getVisibility() }
}

private class MacroCallItemNode extends ItemNode instanceof MacroCall {
Expand Down Expand Up @@ -232,6 +232,12 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
override Visibility getVisibility() { none() }
}

private class TypeParamItemNode extends ItemNode instanceof TypeParam {
override string getName() { result = TypeParam.super.getName().getText() }

override Visibility getVisibility() { none() }
}

/** Holds if `item` has the name `name` and is a top-level item inside `f`. */
private predicate sourceFileEdge(SourceFile f, string name, ItemNode item) {
item = f.getAnItem() and
Expand Down
12 changes: 12 additions & 0 deletions rust/ql/lib/codeql/rust/elements/internal/StructImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* INTERNAL: Do not use.
*/

private import rust
private import codeql.rust.elements.internal.generated.Struct

/**
Expand All @@ -20,5 +21,16 @@ module Impl {
*/
class Struct extends Generated::Struct {
override string toString() { result = "struct " + this.getName().getText() }

/** Gets the record field named `name`, if any. */
pragma[nomagic]
RecordField getRecordField(string name) {
result = this.getFieldList().(RecordFieldList).getAField() and
result.getName().getText() = name
}

/** Gets the `i`th tuple field, if any. */
pragma[nomagic]
TupleField getTupleField(int i) { result = this.getFieldList().(TupleFieldList).getField(i) }
}
}
Loading

0 comments on commit 66bea42

Please sign in to comment.