Skip to content

Commit

Permalink
[XLS] Fix proc-state-legalization for state-reads that influence next…
Browse files Browse the repository at this point in the history
…_values

If a next-value (state write)'s predicate depends on the read of that same state element, we would previously add the predicate to the state read - introducing a cycle in the graph, and causing codegen to fail with an error.

To fix this, we automatically use a variant of the write's predicate, removing all dependency on the state read's value, but guaranteeing that the new predicate is true whenever the original expression would have been true.

PiperOrigin-RevId: 719457555
  • Loading branch information
ericastor authored and copybara-github committed Jan 24, 2025
1 parent 14e095b commit 3466ae5
Show file tree
Hide file tree
Showing 5 changed files with 491 additions and 35 deletions.
100 changes: 100 additions & 0 deletions xls/ir/node_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/ir/bits.h"
#include "xls/ir/channel.h"
#include "xls/ir/dfs_visitor.h"
#include "xls/ir/function_base.h"
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
Expand Down Expand Up @@ -825,4 +826,103 @@ bool AreAllLiteral(absl::Span<Node* const> nodes) {
return absl::c_all_of(nodes, [](Node* i) -> bool { return IsLiteral(i); });
}

namespace {

class NodeSearch : public DfsVisitorWithDefault {
public:
explicit NodeSearch(Node* target) : target_(target) {}

absl::Status DefaultHandler(Node* node) override {
if (node == target_) {
// We've found our target, and can cancel the search. (This causes
// Node::Accept to return early, since we've already accomplished our
// goal.)
return absl::CancelledError();
}
return absl::OkStatus();
}

private:
Node* target_;
};

} // namespace

bool IsAncestorOf(Node* a, Node* b) {
CHECK_NE(a, nullptr);
CHECK_NE(b, nullptr);

if (a->function_base() != b->function_base()) {
return false;
}
if (a == b) {
return false;
}

NodeSearch visitor(a);
absl::Status visitor_status = b->Accept(&visitor);
CHECK(visitor_status.ok() || absl::IsCancelled(visitor_status));
return visitor.IsVisited(a);
}

absl::StatusOr<Node*> RemoveNodeFromBooleanExpression(Node* to_remove,
Node* expression,
bool favored_outcome) {
XLS_RET_CHECK(expression->GetType()->IsBits());
XLS_RET_CHECK_EQ(expression->GetType()->AsBitsOrDie()->bit_count(), 1)
<< expression->ToString();

if (expression == to_remove) {
return expression->function_base()->MakeNode<Literal>(
expression->loc(), Value(UBits((favored_outcome ? 1 : 0), 1)));
}

if (expression->op() == Op::kNot) {
XLS_ASSIGN_OR_RETURN(
Node * new_operand,
RemoveNodeFromBooleanExpression(to_remove, expression->operand(0),
/*favored_outcome=*/!favored_outcome));
if (new_operand == expression->operand(0)) {
// No change was necessary; apparently `to_remove` was not present.
return expression;
} else {
return expression->function_base()->MakeNode<UnOp>(expression->loc(),
new_operand, Op::kNot);
}
}

if (expression->OpIn({Op::kAnd, Op::kOr, Op::kNand, Op::kNor})) {
const bool favored_operand =
favored_outcome ^ expression->OpIn({Op::kNand, Op::kNor});

std::vector<Node*> new_operands;
new_operands.reserve(expression->operands().size());
bool changed = false;
for (Node* operand : expression->operands()) {
XLS_ASSIGN_OR_RETURN(
Node * new_operand,
RemoveNodeFromBooleanExpression(to_remove, operand, favored_operand));
new_operands.push_back(new_operand);
if (new_operand != operand) {
changed = true;
}
}
if (changed) {
return expression->function_base()->MakeNode<NaryOp>(
expression->loc(), new_operands, expression->op());
} else {
// No change was necessary; apparently `to_remove` was not present.
return expression;
}
}

if (IsAncestorOf(to_remove, expression)) {
// We're unable to remove `to_remove` from `expression` directly, but it is
// an ancestor; just replace the entire expression with a literal.
return expression->function_base()->MakeNode<Literal>(
expression->loc(), Value(UBits((favored_outcome ? 1 : 0), 1)));
}
return expression;
}

} // namespace xls
15 changes: 15 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,21 @@ inline absl::StatusOr<Node*> UnsignedUpperBoundLiteral(Node* v,
// Check if all nodes are literals
bool AreAllLiteral(absl::Span<Node* const> nodes);

// Returns whether `a` is an ancestor of `b`; i.e., whether `b` could possibly
// be affected by a change to `a`. Returns false if `a` and `b` are the same
// node.
bool IsAncestorOf(Node* a, Node* b);

// Removes the given node from the given boolean expression, returning the
// result. We guarantee that the result no longer depends on `to_remove`, and
// that whenever `old_expression == favored_outcome`, `new_expression ==
// favored_outcome`. Note that `new_expression` may be `favored_outcome` in more
// cases than `old_expression`; if necessary, `new_expression` may always equal
// `favored_outcome`.
absl::StatusOr<Node*> RemoveNodeFromBooleanExpression(Node* to_remove,
Node* expression,
bool favored_outcome);

} // namespace xls

#endif // XLS_IR_NODE_UTIL_H_
140 changes: 140 additions & 0 deletions xls/ir/node_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -656,5 +656,145 @@ TEST_F(NodeUtilTest, UnsignedBoundByLiterals) {
m::ULt(m::Param("foo2"), m::Literal(UBits(10, 10))),
{m::Literal(UBits(10, 10))}, m::Param("foo2")));
}

TEST_F(NodeUtilTest, IsAncestorOf) {
std::unique_ptr<Package> p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue a = fb.Param("a", p->GetBitsType(100));
BValue b = fb.Param("b", p->GetBitsType(100));
BValue c = fb.Param("c", p->GetBitsType(100));
BValue d = fb.Param("d", p->GetBitsType(100));
BValue a_or_b = fb.Or({a, b});
BValue a_or_b_or_d = fb.Or({a_or_b, d});
BValue c_plus_d = fb.Add(c, d);
BValue and_a_b_c = fb.And({a, b, c});
BValue big_or = fb.Or({and_a_b_c, c_plus_d});

EXPECT_FALSE(IsAncestorOf(a.node(), b.node()));
EXPECT_FALSE(IsAncestorOf(c.node(), c.node()));
EXPECT_FALSE(IsAncestorOf(d.node(), c.node()));

EXPECT_TRUE(IsAncestorOf(a.node(), a_or_b.node()));
EXPECT_TRUE(IsAncestorOf(b.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(c.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(d.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), a_or_b.node()));
EXPECT_FALSE(IsAncestorOf(big_or.node(), a_or_b.node()));

EXPECT_TRUE(IsAncestorOf(a.node(), a_or_b_or_d.node()));
EXPECT_TRUE(IsAncestorOf(b.node(), a_or_b_or_d.node()));
EXPECT_FALSE(IsAncestorOf(c.node(), a_or_b_or_d.node()));
EXPECT_TRUE(IsAncestorOf(d.node(), a_or_b_or_d.node()));
EXPECT_TRUE(IsAncestorOf(a_or_b.node(), a_or_b_or_d.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), a_or_b_or_d.node()));
EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), a_or_b_or_d.node()));
EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), a_or_b_or_d.node()));
EXPECT_FALSE(IsAncestorOf(big_or.node(), a_or_b_or_d.node()));

EXPECT_FALSE(IsAncestorOf(a.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(b.node(), c_plus_d.node()));
EXPECT_TRUE(IsAncestorOf(c.node(), c_plus_d.node()));
EXPECT_TRUE(IsAncestorOf(d.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), c_plus_d.node()));
EXPECT_FALSE(IsAncestorOf(big_or.node(), c_plus_d.node()));

EXPECT_TRUE(IsAncestorOf(a.node(), and_a_b_c.node()));
EXPECT_TRUE(IsAncestorOf(b.node(), and_a_b_c.node()));
EXPECT_TRUE(IsAncestorOf(c.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(d.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(and_a_b_c.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(c_plus_d.node(), and_a_b_c.node()));
EXPECT_FALSE(IsAncestorOf(big_or.node(), and_a_b_c.node()));

EXPECT_TRUE(IsAncestorOf(a.node(), big_or.node()));
EXPECT_TRUE(IsAncestorOf(b.node(), big_or.node()));
EXPECT_TRUE(IsAncestorOf(c.node(), big_or.node()));
EXPECT_TRUE(IsAncestorOf(d.node(), big_or.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b.node(), big_or.node()));
EXPECT_FALSE(IsAncestorOf(a_or_b_or_d.node(), big_or.node()));
EXPECT_TRUE(IsAncestorOf(and_a_b_c.node(), big_or.node()));
EXPECT_TRUE(IsAncestorOf(c_plus_d.node(), big_or.node()));
EXPECT_FALSE(IsAncestorOf(big_or.node(), big_or.node()));
}

TEST_F(NodeUtilTest, RemoveNodeFromBooleanExpression) {
std::unique_ptr<Package> p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue a = fb.Param("a", p->GetBitsType(1));
BValue b = fb.Param("b", p->GetBitsType(1));
BValue c = fb.Param("c", p->GetBitsType(1));
BValue d = fb.Param("d", p->GetBitsType(1));
BValue a_or_b = fb.Or({a, b});
BValue a_or_b_or_d = fb.Or({a_or_b, d});
BValue c_plus_d = fb.Add(c, d);
BValue and_a_b_c = fb.And({a, b, c});
BValue big_nor = fb.Nor({and_a_b_c, c_plus_d});

EXPECT_THAT(RemoveNodeFromBooleanExpression(a.node(), a_or_b.node(), true),
IsOkAndHolds(m::Or(m::Literal(1), m::Param("b"))));
EXPECT_THAT(RemoveNodeFromBooleanExpression(a.node(), a_or_b.node(), false),
IsOkAndHolds(m::Or(m::Literal(0), m::Param("b"))));
EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), a_or_b.node(), true),
IsOkAndHolds(a_or_b.node()));
EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), a_or_b.node(), false),
IsOkAndHolds(a_or_b.node()));

EXPECT_THAT(
RemoveNodeFromBooleanExpression(b.node(), a_or_b_or_d.node(), true),
IsOkAndHolds(m::Or(m::Or(m::Param("a"), m::Literal(1)), m::Param("d"))));
EXPECT_THAT(
RemoveNodeFromBooleanExpression(b.node(), a_or_b_or_d.node(), false),
IsOkAndHolds(m::Or(m::Or(m::Param("a"), m::Literal(0)), m::Param("d"))));

EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), c_plus_d.node(), true),
IsOkAndHolds(m::Literal(1)));
EXPECT_THAT(RemoveNodeFromBooleanExpression(c.node(), c_plus_d.node(), false),
IsOkAndHolds(m::Literal(0)));

EXPECT_THAT(
RemoveNodeFromBooleanExpression(a.node(), big_nor.node(), true),
IsOkAndHolds(m::Nor(m::And(m::Literal(0), m::Param("b"), m::Param("c")),
m::Add(m::Param("c"), m::Param("d")))));
EXPECT_THAT(
RemoveNodeFromBooleanExpression(a.node(), big_nor.node(), false),
IsOkAndHolds(m::Nor(m::And(m::Literal(1), m::Param("b"), m::Param("c")),
m::Add(m::Param("c"), m::Param("d")))));

EXPECT_THAT(
RemoveNodeFromBooleanExpression(b.node(), big_nor.node(), true),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Literal(0), m::Param("c")),
m::Add(m::Param("c"), m::Param("d")))));
EXPECT_THAT(
RemoveNodeFromBooleanExpression(b.node(), big_nor.node(), false),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Literal(1), m::Param("c")),
m::Add(m::Param("c"), m::Param("d")))));

EXPECT_THAT(
RemoveNodeFromBooleanExpression(c.node(), big_nor.node(), true),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Literal(0)),
m::Literal(0))));
EXPECT_THAT(
RemoveNodeFromBooleanExpression(c.node(), big_nor.node(), false),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Literal(1)),
m::Literal(1))));

EXPECT_THAT(
RemoveNodeFromBooleanExpression(d.node(), big_nor.node(), true),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Param("c")),
m::Literal(0))));
EXPECT_THAT(
RemoveNodeFromBooleanExpression(d.node(), big_nor.node(), false),
IsOkAndHolds(m::Nor(m::And(m::Param("a"), m::Param("b"), m::Param("c")),
m::Literal(1))));
}

} // namespace
} // namespace xls
Loading

0 comments on commit 3466ae5

Please sign in to comment.